wilbin commited on
Commit
8896a5f
1 Parent(s): 8d4e1fe

Upload 248 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. samsledje-D-SCRIPT-8a55490/.gitignore +10 -0
  3. samsledje-D-SCRIPT-8a55490/.readthedocs.yml +25 -0
  4. samsledje-D-SCRIPT-8a55490/CHANGELOG.md +30 -0
  5. samsledje-D-SCRIPT-8a55490/LICENSE +21 -0
  6. samsledje-D-SCRIPT-8a55490/README.md +18 -0
  7. samsledje-D-SCRIPT-8a55490/data/pairs/ecoli_test.tsv +0 -0
  8. samsledje-D-SCRIPT-8a55490/data/pairs/fly_test.tsv +0 -0
  9. samsledje-D-SCRIPT-8a55490/data/pairs/human_test.tsv +0 -0
  10. samsledje-D-SCRIPT-8a55490/data/pairs/human_train.tsv +3 -0
  11. samsledje-D-SCRIPT-8a55490/data/pairs/mouse_test.tsv +0 -0
  12. samsledje-D-SCRIPT-8a55490/data/pairs/worm_test.tsv +0 -0
  13. samsledje-D-SCRIPT-8a55490/data/pairs/yeast_test.tsv +0 -0
  14. samsledje-D-SCRIPT-8a55490/data/seqs/ecoli.fasta +0 -0
  15. samsledje-D-SCRIPT-8a55490/data/seqs/fly.fasta +0 -0
  16. samsledje-D-SCRIPT-8a55490/data/seqs/human.fasta +3 -0
  17. samsledje-D-SCRIPT-8a55490/data/seqs/mouse.fasta +3 -0
  18. samsledje-D-SCRIPT-8a55490/data/seqs/worm.fasta +0 -0
  19. samsledje-D-SCRIPT-8a55490/data/seqs/yeast.fasta +0 -0
  20. samsledje-D-SCRIPT-8a55490/docs/.nojekyll +0 -0
  21. samsledje-D-SCRIPT-8a55490/docs/Makefile +20 -0
  22. samsledje-D-SCRIPT-8a55490/docs/build/doctrees/api/dscript.commands.doctree +0 -0
  23. samsledje-D-SCRIPT-8a55490/docs/build/doctrees/api/dscript.models.doctree +0 -0
  24. samsledje-D-SCRIPT-8a55490/docs/build/doctrees/api/index.doctree +0 -0
  25. samsledje-D-SCRIPT-8a55490/docs/build/doctrees/data.doctree +0 -0
  26. samsledje-D-SCRIPT-8a55490/docs/build/doctrees/environment.pickle +3 -0
  27. samsledje-D-SCRIPT-8a55490/docs/build/doctrees/index.doctree +0 -0
  28. samsledje-D-SCRIPT-8a55490/docs/build/doctrees/installation.doctree +0 -0
  29. samsledje-D-SCRIPT-8a55490/docs/build/doctrees/usage.doctree +0 -0
  30. samsledje-D-SCRIPT-8a55490/docs/build/html/.buildinfo +4 -0
  31. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/alphabets.html +276 -0
  32. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/commands/embed.html +237 -0
  33. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/commands/eval.html +371 -0
  34. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/commands/predict.html +344 -0
  35. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/commands/train.html +755 -0
  36. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/fasta.html +279 -0
  37. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/language_model.html +327 -0
  38. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/models/contact.html +331 -0
  39. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/models/embedding.html +369 -0
  40. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/models/interaction.html +413 -0
  41. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/pretrained.html +291 -0
  42. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/utils.html +354 -0
  43. samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/index.html +209 -0
  44. samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/api/dscript.commands.rst.txt +42 -0
  45. samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/api/dscript.models.rst.txt +26 -0
  46. samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/api/index.rst.txt +48 -0
  47. samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/data.rst.txt +43 -0
  48. samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/index.rst.txt +39 -0
  49. samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/installation.rst.txt +49 -0
  50. samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/usage.rst.txt +181 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ samsledje-D-SCRIPT-8a55490/data/pairs/human_train.tsv filter=lfs diff=lfs merge=lfs -text
37
+ samsledje-D-SCRIPT-8a55490/data/seqs/human.fasta filter=lfs diff=lfs merge=lfs -text
38
+ samsledje-D-SCRIPT-8a55490/data/seqs/mouse.fasta filter=lfs diff=lfs merge=lfs -text
samsledje-D-SCRIPT-8a55490/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ build/*
2
+ scratch/*
3
+ dist/*
4
+ **/*.pt
5
+ .vscode/**
6
+ *.egg-info
7
+ **/*.h5
8
+ **/.ipynb_checkpoints/**
9
+ **/__pycache__/**
10
+ collect_env.py
samsledje-D-SCRIPT-8a55490/.readthedocs.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # .readthedocs.yml
2
+ # Read the Docs configuration file
3
+ # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4
+
5
+ # Required
6
+ version: 2
7
+
8
+ python:
9
+ version: 3.7
10
+ install:
11
+ - requirements: docs/requirements.txt
12
+ - method: pip
13
+ path: .
14
+
15
+ # Build documentation in the docs/ directory with Sphinx
16
+ sphinx:
17
+ configuration: docs/source/conf.py
18
+
19
+ # Build documentation with MkDocs
20
+ #mkdocs:
21
+ # configuration: mkdocs.yml
22
+
23
+ # Optionally build your docs in additional formats such as PDF
24
+ formats:
25
+ - pdf
samsledje-D-SCRIPT-8a55490/CHANGELOG.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # To Do
2
+ - Full logging system (issue #5)
3
+ - Add multi-gpu support (issue #6)
4
+ - Use multiple workers to load embeddings / support for loading embeddings on the fly to reduce memory usage (issue #8/11)
5
+ - Add convenience function to generate candidates - all pairs from a list / cartesian produt of multiple lists
6
+ - Add error handling for calledProcessError in utils.gpu_mem
7
+
8
+ # v0
9
+
10
+ ## v0.1
11
+
12
+ ### v0.1.5: 2021-06-23 -- Bug Fix - Augment and Documentation
13
+ - Updated package level imports
14
+ - Updated documentation
15
+ - Fixed issue #13: improper augmentation of data
16
+ - Fixed issue #12: overwrites cmap data sets if they already exist
17
+
18
+ ### v0.1.4: 2021-03-05 -- Bug Fix - Typo in `ContactModule.forward()`
19
+ - Fixed issue #7: bug which would crash contact module if called directly
20
+
21
+ ### v0.1.3: 2021-02-17 -- Bug Fix - Pairs too large for GPU
22
+ - Fixed issues #3, #4
23
+ - Basic logging system implemented to report skipped pairs
24
+ - Fixed wrong variable name in loading from sequence file
25
+ - Updated documentation
26
+
27
+ ### v0.1.2: 2020-11-30 -- Bug Fix - Eval Mode
28
+ - Model should be put into `eval()` mode before prediction or evaluation, and when new models are downloaded - this makes the output deterministic by disabling dropout layers
29
+
30
+ ### v0.1.1: 2020-11-18 -- First Beta Release
samsledje-D-SCRIPT-8a55490/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Samuel Sledzieski
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
samsledje-D-SCRIPT-8a55490/README.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # D-SCRIPT
2
+ ![D-SCRIPT Architecture](docs/source/img/dscript_architecture.png)
3
+
4
+ <!--- #![GitHub release (latest by date)](https://img.shields.io/github/v/release/samsledje/D-SCRIPT) --->
5
+ [![D-SCRIPT](https://img.shields.io/github/v/release/samsledje/D-SCRIPT?include_prereleases)](https://github.com/samsledje/D-SCRIPT/releases)
6
+ [![PyPI](https://img.shields.io/pypi/v/dscript)](https://pypi.org/project/dscript/)
7
+ [![Documentation Status](https://readthedocs.org/projects/d-script/badge/?version=main)](https://d-script.readthedocs.io/en/main/?badge=main)
8
+ [![License](https://img.shields.io/github/license/samsledje/D-SCRIPT)](https://github.com/samsledje/D-SCRIPT/blob/main/LICENSE)
9
+ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
10
+
11
+
12
+ D-SCRIPT is a deep learning method for predicting a physical interaction between two proteins given just their sequences. It generalizes well to new species and is robust to limitations in training data size. Its design reflects the intuition that for two proteins to physically interact, a subset of amino acids from each protein should be in con-tact with the other. The intermediate stages of D-SCRIPT directly implement this intuition, with the penultimate stage in D-SCRIPT being a rough estimate of the inter-protein contact map of the protein dimer. This structurally-motivated design enhances the interpretability of the results and, since structure is more conserved evolutionarily than sequence, improves generalizability across species.
13
+
14
+ - D-SCRIPT is described in the paper [“Sequence-based prediction of protein-protein interactions: a structure-aware interpretable deep learning model”](https://www.biorxiv.org/content/10.1101/2021.01.22.427866v1) by Sam Sledzieski, Rohit Singh, Lenore Cowen and Bonnie Berger.
15
+
16
+ - [Homepage](http://dscript.csail.mit.edu)
17
+
18
+ - [Documentation](https://d-script.readthedocs.io/en/main/)
samsledje-D-SCRIPT-8a55490/data/pairs/ecoli_test.tsv ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/pairs/fly_test.tsv ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/pairs/human_test.tsv ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/pairs/human_train.tsv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a51caf0b590decf96911b09d1e2cc6afc9a9d669d4e67d8bb3c2f1c94e16cd0b
3
+ size 18558848
samsledje-D-SCRIPT-8a55490/data/pairs/mouse_test.tsv ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/pairs/worm_test.tsv ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/pairs/yeast_test.tsv ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/seqs/ecoli.fasta ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/seqs/fly.fasta ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/seqs/human.fasta ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff027c405225204c9c3469ee2aa6dee807253a00af96936f4776e3580319cb14
3
+ size 30928643
samsledje-D-SCRIPT-8a55490/data/seqs/mouse.fasta ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17a22b563760987f5ba4732e55b02327eb2862b65e9685c0207e190bd4092140
3
+ size 17029641
samsledje-D-SCRIPT-8a55490/data/seqs/worm.fasta ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/data/seqs/yeast.fasta ADDED
The diff for this file is too large to render. See raw diff
 
samsledje-D-SCRIPT-8a55490/docs/.nojekyll ADDED
File without changes
samsledje-D-SCRIPT-8a55490/docs/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = source
9
+ BUILDDIR = build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
samsledje-D-SCRIPT-8a55490/docs/build/doctrees/api/dscript.commands.doctree ADDED
Binary file (53.1 kB). View file
 
samsledje-D-SCRIPT-8a55490/docs/build/doctrees/api/dscript.models.doctree ADDED
Binary file (119 kB). View file
 
samsledje-D-SCRIPT-8a55490/docs/build/doctrees/api/index.doctree ADDED
Binary file (102 kB). View file
 
samsledje-D-SCRIPT-8a55490/docs/build/doctrees/data.doctree ADDED
Binary file (11.2 kB). View file
 
samsledje-D-SCRIPT-8a55490/docs/build/doctrees/environment.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e1c0741329fa1f6bb224188ae77cd9f068b27dd408083cefa453252320c5b19
3
+ size 98992
samsledje-D-SCRIPT-8a55490/docs/build/doctrees/index.doctree ADDED
Binary file (9.5 kB). View file
 
samsledje-D-SCRIPT-8a55490/docs/build/doctrees/installation.doctree ADDED
Binary file (6.63 kB). View file
 
samsledje-D-SCRIPT-8a55490/docs/build/doctrees/usage.doctree ADDED
Binary file (17.9 kB). View file
 
samsledje-D-SCRIPT-8a55490/docs/build/html/.buildinfo ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Sphinx build info version 1
2
+ # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
3
+ config: 12603d33db63e6503ae2eeaded0b39ac
4
+ tags: 645f666f9bcd5a90fca523b33c5a78b7
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/alphabets.html ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8">
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
9
+
10
+ <title>dscript.alphabets &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+ <!--[if lt IE 9]>
24
+ <script src="../../_static/js/html5shiv.min.js"></script>
25
+ <![endif]-->
26
+
27
+
28
+ <script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
29
+ <script src="../../_static/jquery.js"></script>
30
+ <script src="../../_static/underscore.js"></script>
31
+ <script src="../../_static/doctools.js"></script>
32
+ <script src="../../_static/language_data.js"></script>
33
+ <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
34
+
35
+ <script type="text/javascript" src="../../_static/js/theme.js"></script>
36
+
37
+
38
+ <link rel="index" title="Index" href="../../genindex.html" />
39
+ <link rel="search" title="Search" href="../../search.html" />
40
+ </head>
41
+
42
+ <body class="wy-body-for-nav">
43
+
44
+
45
+ <div class="wy-grid-for-nav">
46
+
47
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
48
+ <div class="wy-side-scroll">
49
+ <div class="wy-side-nav-search" >
50
+
51
+
52
+
53
+ <a href="../../index.html" class="icon icon-home" alt="Documentation Home"> D-SCRIPT
54
+
55
+
56
+
57
+ </a>
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+ <div role="search">
66
+ <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
67
+ <input type="text" name="q" placeholder="Search docs" />
68
+ <input type="hidden" name="check_keywords" value="yes" />
69
+ <input type="hidden" name="area" value="default" />
70
+ </form>
71
+ </div>
72
+
73
+
74
+ </div>
75
+
76
+
77
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
78
+
79
+
80
+
81
+
82
+
83
+
84
+ <ul>
85
+ <li class="toctree-l1"><a class="reference internal" href="../../installation.html">Installation</a></li>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../usage.html">Usage</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../data.html">Data</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../api/index.html">API</a></li>
89
+ </ul>
90
+
91
+
92
+
93
+ </div>
94
+
95
+ </div>
96
+ </nav>
97
+
98
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
99
+
100
+
101
+ <nav class="wy-nav-top" aria-label="top navigation">
102
+
103
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
104
+ <a href="../../index.html">D-SCRIPT</a>
105
+
106
+ </nav>
107
+
108
+
109
+ <div class="wy-nav-content">
110
+
111
+ <div class="rst-content">
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+ <div role="navigation" aria-label="breadcrumbs navigation">
130
+
131
+ <ul class="wy-breadcrumbs">
132
+
133
+ <li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
134
+
135
+ <li><a href="../index.html">Module code</a> &raquo;</li>
136
+
137
+ <li>dscript.alphabets</li>
138
+
139
+
140
+ <li class="wy-breadcrumbs-aside">
141
+
142
+ </li>
143
+
144
+ </ul>
145
+
146
+
147
+ <hr/>
148
+ </div>
149
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
150
+ <div itemprop="articleBody">
151
+
152
+ <h1>Source code for dscript.alphabets</h1><div class="highlight"><pre>
153
+ <span></span><span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span><span class="p">,</span> <span class="n">division</span>
154
+
155
+ <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
156
+
157
+
158
+ <div class="viewcode-block" id="Alphabet"><a class="viewcode-back" href="../../api/index.html#dscript.alphabets.Alphabet">[docs]</a><span class="k">class</span> <span class="nc">Alphabet</span><span class="p">:</span>
159
+ <span class="sd">&quot;&quot;&quot;</span>
160
+ <span class="sd"> From `Bepler &amp; Berger &lt;https://github.com/tbepler/protein-sequence-embedding-iclr2019&gt;`_.</span>
161
+
162
+ <span class="sd"> :param chars: List of characters in alphabet</span>
163
+ <span class="sd"> :type chars: byte str</span>
164
+ <span class="sd"> :param encoding: Mapping of characters to numbers [default: encoding]</span>
165
+ <span class="sd"> :type encoding: np.ndarray</span>
166
+ <span class="sd"> :param mask: Set encoding mask [default: False]</span>
167
+ <span class="sd"> :type mask: bool</span>
168
+ <span class="sd"> :param missing: Number to use for a value outside the alphabet [default: 255]</span>
169
+ <span class="sd"> :type missing: int</span>
170
+ <span class="sd"> &quot;&quot;&quot;</span>
171
+
172
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">chars</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">missing</span><span class="o">=</span><span class="mi">255</span><span class="p">):</span>
173
+ <span class="bp">self</span><span class="o">.</span><span class="n">chars</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">chars</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
174
+ <span class="bp">self</span><span class="o">.</span><span class="n">encoding</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">+</span> <span class="n">missing</span>
175
+ <span class="k">if</span> <span class="n">encoding</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
176
+ <span class="bp">self</span><span class="o">.</span><span class="n">encoding</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">chars</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chars</span><span class="p">))</span>
177
+ <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chars</span><span class="p">)</span>
178
+ <span class="k">else</span><span class="p">:</span>
179
+ <span class="bp">self</span><span class="o">.</span><span class="n">encoding</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">chars</span><span class="p">]</span> <span class="o">=</span> <span class="n">encoding</span>
180
+ <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">=</span> <span class="n">encoding</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
181
+ <span class="bp">self</span><span class="o">.</span><span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span>
182
+ <span class="k">if</span> <span class="n">mask</span><span class="p">:</span>
183
+ <span class="bp">self</span><span class="o">.</span><span class="n">size</span> <span class="o">-=</span> <span class="mi">1</span>
184
+
185
+ <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
186
+ <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span>
187
+
188
+ <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">i</span><span class="p">):</span>
189
+ <span class="k">return</span> <span class="nb">chr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">chars</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
190
+
191
+ <div class="viewcode-block" id="Alphabet.encode"><a class="viewcode-back" href="../../api/index.html#dscript.alphabets.Alphabet.encode">[docs]</a> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
192
+ <span class="sd">&quot;&quot;&quot;</span>
193
+ <span class="sd"> Encode a byte string into alphabet indices</span>
194
+ <span class="sd"> </span>
195
+ <span class="sd"> :param x: Amino acid string</span>
196
+ <span class="sd"> :type x: byte str</span>
197
+ <span class="sd"> :return: Numeric encoding</span>
198
+ <span class="sd"> :rtype: np.ndarray</span>
199
+ <span class="sd"> &quot;&quot;&quot;</span>
200
+ <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
201
+ <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoding</span><span class="p">[</span><span class="n">x</span><span class="p">]</span></div>
202
+
203
+ <div class="viewcode-block" id="Alphabet.decode"><a class="viewcode-back" href="../../api/index.html#dscript.alphabets.Alphabet.decode">[docs]</a> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
204
+ <span class="sd">&quot;&quot;&quot;</span>
205
+ <span class="sd"> Decode numeric encoding to byte string of this alphabet</span>
206
+
207
+ <span class="sd"> :param x: Numeric encoding</span>
208
+ <span class="sd"> :type x: np.ndarray</span>
209
+ <span class="sd"> :return: Amino acid string</span>
210
+ <span class="sd"> :rtype: byte str</span>
211
+ <span class="sd"> &quot;&quot;&quot;</span>
212
+ <span class="n">string</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">chars</span><span class="p">[</span><span class="n">x</span><span class="p">]</span>
213
+ <span class="k">return</span> <span class="n">string</span><span class="o">.</span><span class="n">tobytes</span><span class="p">()</span></div></div>
214
+
215
+
216
+ <div class="viewcode-block" id="Uniprot21"><a class="viewcode-back" href="../../api/index.html#dscript.alphabets.Uniprot21">[docs]</a><span class="k">class</span> <span class="nc">Uniprot21</span><span class="p">(</span><span class="n">Alphabet</span><span class="p">):</span>
217
+ <span class="sd">&quot;&quot;&quot;</span>
218
+ <span class="sd"> Uniprot 21 Amino Acid Encoding.</span>
219
+
220
+ <span class="sd"> From `Bepler &amp; Berger &lt;https://github.com/tbepler/protein-sequence-embedding-iclr2019&gt;`_.</span>
221
+ <span class="sd"> &quot;&quot;&quot;</span>
222
+
223
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
224
+ <span class="n">chars</span> <span class="o">=</span> <span class="sa">b</span><span class="s2">&quot;ARNDCQEGHILKMFPSTWYVXOUBZ&quot;</span>
225
+ <span class="n">encoding</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">chars</span><span class="p">))</span>
226
+ <span class="n">encoding</span><span class="p">[</span><span class="mi">21</span><span class="p">:]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">11</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">]</span> <span class="c1"># encode &#39;OUBZ&#39; as synonyms</span>
227
+ <span class="nb">super</span><span class="p">(</span><span class="n">Uniprot21</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">chars</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="n">encoding</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">missing</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span></div>
228
+ </pre></div>
229
+
230
+ </div>
231
+
232
+ </div>
233
+ <footer>
234
+
235
+
236
+ <hr/>
237
+
238
+ <div role="contentinfo">
239
+ <p>
240
+
241
+ &copy; Copyright 2020, Samuel Sledzieski, Rohit Singh
242
+
243
+ </p>
244
+ </div>
245
+
246
+
247
+
248
+ Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a
249
+
250
+ <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a>
251
+
252
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
253
+
254
+ </footer>
255
+
256
+ </div>
257
+ </div>
258
+
259
+ </section>
260
+
261
+ </div>
262
+
263
+
264
+ <script type="text/javascript">
265
+ jQuery(function () {
266
+ SphinxRtdTheme.Navigation.enable(true);
267
+ });
268
+ </script>
269
+
270
+
271
+
272
+
273
+
274
+
275
+ </body>
276
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/commands/embed.html ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8">
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
9
+
10
+ <title>dscript.commands.embed &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+ <!--[if lt IE 9]>
24
+ <script src="../../../_static/js/html5shiv.min.js"></script>
25
+ <![endif]-->
26
+
27
+
28
+ <script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
29
+ <script src="../../../_static/jquery.js"></script>
30
+ <script src="../../../_static/underscore.js"></script>
31
+ <script src="../../../_static/doctools.js"></script>
32
+ <script src="../../../_static/language_data.js"></script>
33
+ <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
34
+
35
+ <script type="text/javascript" src="../../../_static/js/theme.js"></script>
36
+
37
+
38
+ <link rel="index" title="Index" href="../../../genindex.html" />
39
+ <link rel="search" title="Search" href="../../../search.html" />
40
+ </head>
41
+
42
+ <body class="wy-body-for-nav">
43
+
44
+
45
+ <div class="wy-grid-for-nav">
46
+
47
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
48
+ <div class="wy-side-scroll">
49
+ <div class="wy-side-nav-search" >
50
+
51
+
52
+
53
+ <a href="../../../index.html" class="icon icon-home" alt="Documentation Home"> D-SCRIPT
54
+
55
+
56
+
57
+ </a>
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+ <div role="search">
66
+ <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
67
+ <input type="text" name="q" placeholder="Search docs" />
68
+ <input type="hidden" name="check_keywords" value="yes" />
69
+ <input type="hidden" name="area" value="default" />
70
+ </form>
71
+ </div>
72
+
73
+
74
+ </div>
75
+
76
+
77
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
78
+
79
+
80
+
81
+
82
+
83
+
84
+ <ul>
85
+ <li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Installation</a></li>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../../data.html">Data</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../../api/index.html">API</a></li>
89
+ </ul>
90
+
91
+
92
+
93
+ </div>
94
+
95
+ </div>
96
+ </nav>
97
+
98
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
99
+
100
+
101
+ <nav class="wy-nav-top" aria-label="top navigation">
102
+
103
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
104
+ <a href="../../../index.html">D-SCRIPT</a>
105
+
106
+ </nav>
107
+
108
+
109
+ <div class="wy-nav-content">
110
+
111
+ <div class="rst-content">
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+ <div role="navigation" aria-label="breadcrumbs navigation">
130
+
131
+ <ul class="wy-breadcrumbs">
132
+
133
+ <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
134
+
135
+ <li><a href="../../index.html">Module code</a> &raquo;</li>
136
+
137
+ <li>dscript.commands.embed</li>
138
+
139
+
140
+ <li class="wy-breadcrumbs-aside">
141
+
142
+ </li>
143
+
144
+ </ul>
145
+
146
+
147
+ <hr/>
148
+ </div>
149
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
150
+ <div itemprop="articleBody">
151
+
152
+ <h1>Source code for dscript.commands.embed</h1><div class="highlight"><pre>
153
+ <span></span><span class="sd">&quot;&quot;&quot;</span>
154
+ <span class="sd">Generate new embeddings using pre-trained language model.</span>
155
+ <span class="sd">&quot;&quot;&quot;</span>
156
+
157
+ <span class="kn">import</span> <span class="nn">argparse</span>
158
+ <span class="kn">from</span> <span class="nn">dscript.language_model</span> <span class="kn">import</span> <span class="n">embed_from_fasta</span>
159
+
160
+
161
+ <div class="viewcode-block" id="add_args"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.embed.add_args">[docs]</a><span class="k">def</span> <span class="nf">add_args</span><span class="p">(</span><span class="n">parser</span><span class="p">):</span>
162
+ <span class="sd">&quot;&quot;&quot;</span>
163
+ <span class="sd"> Create parser for command line utility.</span>
164
+
165
+ <span class="sd"> :meta private:</span>
166
+ <span class="sd"> &quot;&quot;&quot;</span>
167
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--seqs&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Sequences to be embedded&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
168
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--outfile&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;h5 file to write results&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
169
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;-d&quot;</span><span class="p">,</span> <span class="s2">&quot;--device&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Compute device to use&quot;</span><span class="p">)</span>
170
+ <span class="k">return</span> <span class="n">parser</span></div>
171
+
172
+
173
+ <div class="viewcode-block" id="main"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.embed.main">[docs]</a><span class="k">def</span> <span class="nf">main</span><span class="p">(</span><span class="n">args</span><span class="p">):</span>
174
+ <span class="sd">&quot;&quot;&quot;</span>
175
+ <span class="sd"> Run embedding from arguments.</span>
176
+
177
+ <span class="sd"> :meta private:</span>
178
+ <span class="sd"> &quot;&quot;&quot;</span>
179
+ <span class="n">inPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">seqs</span>
180
+ <span class="n">outPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">outfile</span>
181
+ <span class="n">device</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">device</span>
182
+ <span class="n">embed_from_fasta</span><span class="p">(</span><span class="n">inPath</span><span class="p">,</span> <span class="n">outPath</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>
183
+
184
+
185
+ <span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
186
+ <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="vm">__doc__</span><span class="p">)</span>
187
+ <span class="n">add_args</span><span class="p">(</span><span class="n">parser</span><span class="p">)</span>
188
+ <span class="n">main</span><span class="p">(</span><span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">())</span>
189
+ </pre></div>
190
+
191
+ </div>
192
+
193
+ </div>
194
+ <footer>
195
+
196
+
197
+ <hr/>
198
+
199
+ <div role="contentinfo">
200
+ <p>
201
+
202
+ &copy; Copyright 2020, Samuel Sledzieski, Rohit Singh
203
+
204
+ </p>
205
+ </div>
206
+
207
+
208
+
209
+ Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a
210
+
211
+ <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a>
212
+
213
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
214
+
215
+ </footer>
216
+
217
+ </div>
218
+ </div>
219
+
220
+ </section>
221
+
222
+ </div>
223
+
224
+
225
+ <script type="text/javascript">
226
+ jQuery(function () {
227
+ SphinxRtdTheme.Navigation.enable(true);
228
+ });
229
+ </script>
230
+
231
+
232
+
233
+
234
+
235
+
236
+ </body>
237
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/commands/eval.html ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+
10
+ <title>dscript.commands.eval &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+ <!--[if lt IE 9]>
27
+ <script src="../../../_static/js/html5shiv.min.js"></script>
28
+ <![endif]-->
29
+
30
+
31
+ <script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
32
+ <script src="../../../_static/jquery.js"></script>
33
+ <script src="../../../_static/underscore.js"></script>
34
+ <script src="../../../_static/doctools.js"></script>
35
+
36
+ <script type="text/javascript" src="../../../_static/js/theme.js"></script>
37
+
38
+
39
+ <link rel="index" title="Index" href="../../../genindex.html" />
40
+ <link rel="search" title="Search" href="../../../search.html" />
41
+ </head>
42
+
43
+ <body class="wy-body-for-nav">
44
+
45
+
46
+ <div class="wy-grid-for-nav">
47
+
48
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
49
+ <div class="wy-side-scroll">
50
+ <div class="wy-side-nav-search" >
51
+
52
+
53
+
54
+ <a href="../../../index.html" class="icon icon-home"> D-SCRIPT
55
+
56
+
57
+
58
+ </a>
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+ <div role="search">
67
+ <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
68
+ <input type="text" name="q" placeholder="Search docs" />
69
+ <input type="hidden" name="check_keywords" value="yes" />
70
+ <input type="hidden" name="area" value="default" />
71
+ </form>
72
+ </div>
73
+
74
+
75
+ </div>
76
+
77
+
78
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
79
+
80
+
81
+
82
+
83
+
84
+
85
+ <ul>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Installation</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../../data.html">Data</a></li>
89
+ <li class="toctree-l1"><a class="reference internal" href="../../../api/index.html">API</a></li>
90
+ </ul>
91
+
92
+
93
+
94
+ </div>
95
+
96
+ </div>
97
+ </nav>
98
+
99
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
100
+
101
+
102
+ <nav class="wy-nav-top" aria-label="top navigation">
103
+
104
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
105
+ <a href="../../../index.html">D-SCRIPT</a>
106
+
107
+ </nav>
108
+
109
+
110
+ <div class="wy-nav-content">
111
+
112
+ <div class="rst-content">
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ <div role="navigation" aria-label="breadcrumbs navigation">
133
+
134
+ <ul class="wy-breadcrumbs">
135
+
136
+ <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
137
+
138
+ <li><a href="../../index.html">Module code</a> &raquo;</li>
139
+
140
+ <li>dscript.commands.eval</li>
141
+
142
+
143
+ <li class="wy-breadcrumbs-aside">
144
+
145
+ </li>
146
+
147
+ </ul>
148
+
149
+
150
+ <hr/>
151
+ </div>
152
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
153
+ <div itemprop="articleBody">
154
+
155
+ <h1>Source code for dscript.commands.eval</h1><div class="highlight"><pre>
156
+ <span></span><span class="sd">&quot;&quot;&quot;</span>
157
+ <span class="sd">Evaluate a trained model.</span>
158
+ <span class="sd">&quot;&quot;&quot;</span>
159
+
160
+ <span class="kn">import</span> <span class="nn">sys</span><span class="o">,</span> <span class="nn">os</span>
161
+ <span class="kn">import</span> <span class="nn">argparse</span>
162
+ <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
163
+ <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
164
+ <span class="kn">import</span> <span class="nn">torch</span>
165
+ <span class="kn">import</span> <span class="nn">h5py</span>
166
+ <span class="kn">import</span> <span class="nn">datetime</span>
167
+ <span class="kn">import</span> <span class="nn">matplotlib</span>
168
+
169
+ <span class="n">matplotlib</span><span class="o">.</span><span class="n">use</span><span class="p">(</span><span class="s2">&quot;Agg&quot;</span><span class="p">)</span>
170
+ <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
171
+ <span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="p">(</span>
172
+ <span class="n">precision_recall_curve</span><span class="p">,</span>
173
+ <span class="n">average_precision_score</span><span class="p">,</span>
174
+ <span class="n">roc_curve</span><span class="p">,</span>
175
+ <span class="n">roc_auc_score</span><span class="p">,</span>
176
+ <span class="p">)</span>
177
+ <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
178
+
179
+
180
+ <span class="k">def</span> <span class="nf">add_args</span><span class="p">(</span><span class="n">parser</span><span class="p">):</span>
181
+ <span class="sd">&quot;&quot;&quot;</span>
182
+ <span class="sd"> Create parser for command line utility.</span>
183
+
184
+ <span class="sd"> :meta private:</span>
185
+ <span class="sd"> &quot;&quot;&quot;</span>
186
+
187
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--model&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Trained prediction model&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
188
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--test&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Test Data&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
189
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--embedding&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;h5 file with embedded sequences&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
190
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;-o&quot;</span><span class="p">,</span> <span class="s2">&quot;--outfile&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Output file to write results&quot;</span><span class="p">)</span>
191
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;-d&quot;</span><span class="p">,</span> <span class="s2">&quot;--device&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Compute device to use&quot;</span><span class="p">)</span>
192
+ <span class="k">return</span> <span class="n">parser</span>
193
+
194
+
195
+ <div class="viewcode-block" id="plot_eval_predictions"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.eval.plot_eval_predictions">[docs]</a><span class="k">def</span> <span class="nf">plot_eval_predictions</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="s2">&quot;figure&quot;</span><span class="p">):</span>
196
+ <span class="sd">&quot;&quot;&quot;</span>
197
+ <span class="sd"> Plot histogram of positive and negative predictions, precision-recall curve, and receiver operating characteristic curve.</span>
198
+
199
+ <span class="sd"> :param y: Labels</span>
200
+ <span class="sd"> :type y: np.ndarray</span>
201
+ <span class="sd"> :param phat: Predicted probabilities</span>
202
+ <span class="sd"> :type phat: np.ndarray</span>
203
+ <span class="sd"> :param path: File prefix for plots to be saved to [default: figure]</span>
204
+ <span class="sd"> :type path: str</span>
205
+ <span class="sd"> &quot;&quot;&quot;</span>
206
+
207
+ <span class="n">pos_phat</span> <span class="o">=</span> <span class="n">predictions</span><span class="p">[</span><span class="n">labels</span> <span class="o">==</span> <span class="mi">1</span><span class="p">]</span>
208
+ <span class="n">neg_phat</span> <span class="o">=</span> <span class="n">predictions</span><span class="p">[</span><span class="n">labels</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span>
209
+
210
+ <span class="n">fig</span><span class="p">,</span> <span class="p">(</span><span class="n">ax1</span><span class="p">,</span> <span class="n">ax2</span><span class="p">)</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
211
+ <span class="n">fig</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s2">&quot;Distribution of Predictions&quot;</span><span class="p">)</span>
212
+ <span class="n">ax1</span><span class="o">.</span><span class="n">hist</span><span class="p">(</span><span class="n">pos_phat</span><span class="p">)</span>
213
+ <span class="n">ax1</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
214
+ <span class="n">ax1</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Positive&quot;</span><span class="p">)</span>
215
+ <span class="n">ax1</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;p-hat&quot;</span><span class="p">)</span>
216
+ <span class="n">ax2</span><span class="o">.</span><span class="n">hist</span><span class="p">(</span><span class="n">neg_phat</span><span class="p">)</span>
217
+ <span class="n">ax2</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
218
+ <span class="n">ax2</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Negative&quot;</span><span class="p">)</span>
219
+ <span class="n">ax2</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;p-hat&quot;</span><span class="p">)</span>
220
+ <span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">&quot;.phat_dist.png&quot;</span><span class="p">)</span>
221
+ <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
222
+
223
+ <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">pr_thresh</span> <span class="o">=</span> <span class="n">precision_recall_curve</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span>
224
+ <span class="n">aupr</span> <span class="o">=</span> <span class="n">average_precision_score</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span>
225
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;AUPR:&quot;</span><span class="p">,</span> <span class="n">aupr</span><span class="p">)</span>
226
+
227
+ <span class="n">plt</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">recall</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">)</span>
228
+ <span class="n">plt</span><span class="o">.</span><span class="n">fill_between</span><span class="p">(</span><span class="n">recall</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">)</span>
229
+ <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">&quot;Recall&quot;</span><span class="p">)</span>
230
+ <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">&quot;Precision&quot;</span><span class="p">)</span>
231
+ <span class="n">plt</span><span class="o">.</span><span class="n">ylim</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.05</span><span class="p">])</span>
232
+ <span class="n">plt</span><span class="o">.</span><span class="n">xlim</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">])</span>
233
+ <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Precision-Recall (AUPR: </span><span class="si">{:.3}</span><span class="s2">)&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">aupr</span><span class="p">))</span>
234
+ <span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">&quot;.aupr.png&quot;</span><span class="p">)</span>
235
+ <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
236
+
237
+ <span class="n">fpr</span><span class="p">,</span> <span class="n">tpr</span><span class="p">,</span> <span class="n">roc_thresh</span> <span class="o">=</span> <span class="n">roc_curve</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span>
238
+ <span class="n">auroc</span> <span class="o">=</span> <span class="n">roc_auc_score</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span>
239
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;AUROC:&quot;</span><span class="p">,</span> <span class="n">auroc</span><span class="p">)</span>
240
+
241
+ <span class="n">plt</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">fpr</span><span class="p">,</span> <span class="n">tpr</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">)</span>
242
+ <span class="n">plt</span><span class="o">.</span><span class="n">fill_between</span><span class="p">(</span><span class="n">fpr</span><span class="p">,</span> <span class="n">tpr</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">)</span>
243
+ <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">&quot;FPR&quot;</span><span class="p">)</span>
244
+ <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">&quot;TPR&quot;</span><span class="p">)</span>
245
+ <span class="n">plt</span><span class="o">.</span><span class="n">ylim</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.05</span><span class="p">])</span>
246
+ <span class="n">plt</span><span class="o">.</span><span class="n">xlim</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">])</span>
247
+ <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Receiver Operating Characteristic (AUROC: </span><span class="si">{:.3}</span><span class="s2">)&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">auroc</span><span class="p">))</span>
248
+ <span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="n">path</span> <span class="o">+</span> <span class="s2">&quot;.auroc.png&quot;</span><span class="p">)</span>
249
+ <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
250
+
251
+
252
+ <span class="k">def</span> <span class="nf">main</span><span class="p">(</span><span class="n">args</span><span class="p">):</span>
253
+ <span class="sd">&quot;&quot;&quot;</span>
254
+ <span class="sd"> Run model evaluation from arguments.</span>
255
+
256
+ <span class="sd"> :meta private:</span>
257
+ <span class="sd"> &quot;&quot;&quot;</span>
258
+
259
+ <span class="c1"># Set Device</span>
260
+ <span class="n">device</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">device</span>
261
+ <span class="n">use_cuda</span> <span class="o">=</span> <span class="p">(</span><span class="n">device</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span>
262
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
263
+ <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
264
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Using CUDA device </span><span class="si">{</span><span class="n">device</span><span class="si">}</span><span class="s2"> - </span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">get_device_name</span><span class="p">(</span><span class="n">device</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
265
+ <span class="k">else</span><span class="p">:</span>
266
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Using CPU&quot;</span><span class="p">)</span>
267
+
268
+ <span class="c1"># Load Model</span>
269
+ <span class="n">model_path</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">model</span>
270
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
271
+ <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">model_path</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
272
+ <span class="k">else</span><span class="p">:</span>
273
+ <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">model_path</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
274
+ <span class="n">model</span><span class="o">.</span><span class="n">use_cuda</span> <span class="o">=</span> <span class="kc">False</span>
275
+
276
+ <span class="n">embeddingPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">embedding</span>
277
+ <span class="n">h5fi</span> <span class="o">=</span> <span class="n">h5py</span><span class="o">.</span><span class="n">File</span><span class="p">(</span><span class="n">embeddingPath</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span>
278
+
279
+ <span class="c1"># Load Pairs</span>
280
+ <span class="n">test_fi</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">test</span>
281
+ <span class="n">test_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">test_fi</span><span class="p">,</span> <span class="n">sep</span><span class="o">=</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
282
+
283
+ <span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">outfile</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
284
+ <span class="n">outPath</span> <span class="o">=</span> <span class="n">datetime</span><span class="o">.</span><span class="n">datetime</span><span class="o">.</span><span class="n">now</span><span class="p">()</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s2">&quot;%Y-%m-</span><span class="si">%d</span><span class="s2">-%H-%M&quot;</span><span class="p">)</span>
285
+ <span class="k">else</span><span class="p">:</span>
286
+ <span class="n">outPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">outfile</span>
287
+ <span class="n">outFile</span> <span class="o">=</span> <span class="nb">open</span><span class="p">(</span><span class="n">outPath</span> <span class="o">+</span> <span class="s2">&quot;.predictions.tsv&quot;</span><span class="p">,</span> <span class="s2">&quot;w+&quot;</span><span class="p">)</span>
288
+
289
+ <span class="n">allProteins</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">test_df</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">test_df</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
290
+
291
+ <span class="n">seqEmbDict</span> <span class="o">=</span> <span class="p">{}</span>
292
+ <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">allProteins</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="s2">&quot;Loading embeddings&quot;</span><span class="p">):</span>
293
+ <span class="n">seqEmbDict</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">h5fi</span><span class="p">[</span><span class="n">i</span><span class="p">][:])</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
294
+
295
+ <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
296
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
297
+ <span class="n">phats</span> <span class="o">=</span> <span class="p">[]</span>
298
+ <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>
299
+ <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="p">(</span><span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">test_df</span><span class="o">.</span><span class="n">iterrows</span><span class="p">(),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">test_df</span><span class="p">),</span> <span class="n">desc</span><span class="o">=</span><span class="s2">&quot;Predicting pairs&quot;</span><span class="p">):</span>
300
+ <span class="k">try</span><span class="p">:</span>
301
+ <span class="n">p0</span> <span class="o">=</span> <span class="n">seqEmbDict</span><span class="p">[</span><span class="n">n0</span><span class="p">]</span>
302
+ <span class="n">p1</span> <span class="o">=</span> <span class="n">seqEmbDict</span><span class="p">[</span><span class="n">n1</span><span class="p">]</span>
303
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
304
+ <span class="n">p0</span> <span class="o">=</span> <span class="n">p0</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
305
+ <span class="n">p1</span> <span class="o">=</span> <span class="n">p1</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
306
+
307
+ <span class="n">pred</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">p0</span><span class="p">,</span> <span class="n">p1</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
308
+ <span class="n">phats</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span>
309
+ <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">label</span><span class="p">)</span>
310
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="se">\t</span><span class="si">{}</span><span class="se">\t</span><span class="si">{}</span><span class="se">\t</span><span class="si">{:.5}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">label</span><span class="p">,</span> <span class="n">pred</span><span class="p">),</span> <span class="n">file</span><span class="o">=</span><span class="n">outFile</span><span class="p">)</span>
311
+ <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
312
+ <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="s2"> x </span><span class="si">{}</span><span class="s2"> - </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">e</span><span class="p">))</span>
313
+
314
+ <span class="n">phats</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">phats</span><span class="p">)</span>
315
+ <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
316
+ <span class="n">plot_eval_predictions</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">phats</span><span class="p">,</span> <span class="n">outPath</span><span class="p">)</span>
317
+
318
+ <span class="n">outFile</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
319
+ <span class="n">h5fi</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
320
+
321
+
322
+ <span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
323
+ <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="vm">__doc__</span><span class="p">)</span>
324
+ <span class="n">add_args</span><span class="p">(</span><span class="n">parser</span><span class="p">)</span>
325
+ <span class="n">main</span><span class="p">(</span><span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">())</span>
326
+ </pre></div>
327
+
328
+ </div>
329
+
330
+ </div>
331
+ <footer>
332
+
333
+ <hr/>
334
+
335
+ <div role="contentinfo">
336
+ <p>
337
+ &#169; Copyright 2020, Samuel Sledzieski, Rohit Singh.
338
+
339
+ </p>
340
+ </div>
341
+
342
+
343
+
344
+ Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
345
+
346
+ <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
347
+
348
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
349
+
350
+ </footer>
351
+ </div>
352
+ </div>
353
+
354
+ </section>
355
+
356
+ </div>
357
+
358
+
359
+ <script type="text/javascript">
360
+ jQuery(function () {
361
+ SphinxRtdTheme.Navigation.enable(true);
362
+ });
363
+ </script>
364
+
365
+
366
+
367
+
368
+
369
+
370
+ </body>
371
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/commands/predict.html ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8">
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
9
+
10
+ <title>dscript.commands.predict &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+ <!--[if lt IE 9]>
24
+ <script src="../../../_static/js/html5shiv.min.js"></script>
25
+ <![endif]-->
26
+
27
+
28
+ <script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
29
+ <script src="../../../_static/jquery.js"></script>
30
+ <script src="../../../_static/underscore.js"></script>
31
+ <script src="../../../_static/doctools.js"></script>
32
+ <script src="../../../_static/language_data.js"></script>
33
+ <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
34
+
35
+ <script type="text/javascript" src="../../../_static/js/theme.js"></script>
36
+
37
+
38
+ <link rel="index" title="Index" href="../../../genindex.html" />
39
+ <link rel="search" title="Search" href="../../../search.html" />
40
+ </head>
41
+
42
+ <body class="wy-body-for-nav">
43
+
44
+
45
+ <div class="wy-grid-for-nav">
46
+
47
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
48
+ <div class="wy-side-scroll">
49
+ <div class="wy-side-nav-search" >
50
+
51
+
52
+
53
+ <a href="../../../index.html" class="icon icon-home" alt="Documentation Home"> D-SCRIPT
54
+
55
+
56
+
57
+ </a>
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+ <div role="search">
66
+ <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
67
+ <input type="text" name="q" placeholder="Search docs" />
68
+ <input type="hidden" name="check_keywords" value="yes" />
69
+ <input type="hidden" name="area" value="default" />
70
+ </form>
71
+ </div>
72
+
73
+
74
+ </div>
75
+
76
+
77
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
78
+
79
+
80
+
81
+
82
+
83
+
84
+ <ul>
85
+ <li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Installation</a></li>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../../data.html">Data</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../../api/index.html">API</a></li>
89
+ </ul>
90
+
91
+
92
+
93
+ </div>
94
+
95
+ </div>
96
+ </nav>
97
+
98
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
99
+
100
+
101
+ <nav class="wy-nav-top" aria-label="top navigation">
102
+
103
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
104
+ <a href="../../../index.html">D-SCRIPT</a>
105
+
106
+ </nav>
107
+
108
+
109
+ <div class="wy-nav-content">
110
+
111
+ <div class="rst-content">
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+ <div role="navigation" aria-label="breadcrumbs navigation">
130
+
131
+ <ul class="wy-breadcrumbs">
132
+
133
+ <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
134
+
135
+ <li><a href="../../index.html">Module code</a> &raquo;</li>
136
+
137
+ <li>dscript.commands.predict</li>
138
+
139
+
140
+ <li class="wy-breadcrumbs-aside">
141
+
142
+ </li>
143
+
144
+ </ul>
145
+
146
+
147
+ <hr/>
148
+ </div>
149
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
150
+ <div itemprop="articleBody">
151
+
152
+ <h1>Source code for dscript.commands.predict</h1><div class="highlight"><pre>
153
+ <span></span><span class="sd">&quot;&quot;&quot;</span>
154
+ <span class="sd">Make new predictions with a pre-trained model.</span>
155
+ <span class="sd">&quot;&quot;&quot;</span>
156
+ <span class="kn">import</span> <span class="nn">sys</span><span class="o">,</span> <span class="nn">os</span>
157
+ <span class="kn">import</span> <span class="nn">torch</span>
158
+ <span class="kn">import</span> <span class="nn">h5py</span>
159
+ <span class="kn">import</span> <span class="nn">argparse</span>
160
+ <span class="kn">import</span> <span class="nn">datetime</span>
161
+ <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
162
+ <span class="kn">from</span> <span class="nn">scipy.special</span> <span class="kn">import</span> <span class="n">comb</span>
163
+ <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
164
+
165
+ <span class="kn">from</span> <span class="nn">dscript.alphabets</span> <span class="kn">import</span> <span class="n">Uniprot21</span>
166
+ <span class="kn">from</span> <span class="nn">dscript.fasta</span> <span class="kn">import</span> <span class="n">parse</span>
167
+ <span class="kn">from</span> <span class="nn">dscript.language_model</span> <span class="kn">import</span> <span class="n">lm_embed</span>
168
+
169
+
170
+ <div class="viewcode-block" id="add_args"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.predict.add_args">[docs]</a><span class="k">def</span> <span class="nf">add_args</span><span class="p">(</span><span class="n">parser</span><span class="p">):</span>
171
+ <span class="sd">&quot;&quot;&quot;</span>
172
+ <span class="sd"> Create parser for command line utility</span>
173
+
174
+ <span class="sd"> :meta private:</span>
175
+ <span class="sd"> &quot;&quot;&quot;</span>
176
+
177
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--pairs&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Candidate protein pairs to predict&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
178
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--model&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Pretrained Model&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
179
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--seqs&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Protein sequences in .fasta format&quot;</span><span class="p">)</span>
180
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--embeddings&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;h5 file with embedded sequences&quot;</span><span class="p">)</span>
181
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;-o&quot;</span><span class="p">,</span> <span class="s2">&quot;--outfile&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;File for predictions&quot;</span><span class="p">)</span>
182
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;-d&quot;</span><span class="p">,</span> <span class="s2">&quot;--device&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Compute device to use&quot;</span><span class="p">)</span>
183
+ <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
184
+ <span class="s2">&quot;--thresh&quot;</span><span class="p">,</span>
185
+ <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span>
186
+ <span class="n">default</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
187
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Positive prediction threshold - used to store contact maps and predictions in a separate file. [default: 0.5]&quot;</span><span class="p">,</span>
188
+ <span class="p">)</span>
189
+ <span class="k">return</span> <span class="n">parser</span></div>
190
+
191
+
192
+ <div class="viewcode-block" id="main"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.predict.main">[docs]</a><span class="k">def</span> <span class="nf">main</span><span class="p">(</span><span class="n">args</span><span class="p">):</span>
193
+ <span class="sd">&quot;&quot;&quot;</span>
194
+ <span class="sd"> Run new prediction from arguments.</span>
195
+
196
+ <span class="sd"> :meta private:</span>
197
+ <span class="sd"> &quot;&quot;&quot;</span>
198
+ <span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">seqs</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">args</span><span class="o">.</span><span class="n">embeddings</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
199
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;One of --seqs or --embeddings is required.&quot;</span><span class="p">)</span>
200
+ <span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
201
+
202
+ <span class="n">csvPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">pairs</span>
203
+ <span class="n">modelPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">model</span>
204
+ <span class="n">outPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">outfile</span>
205
+ <span class="n">seqPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">seqs</span>
206
+ <span class="n">embPath</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">embeddings</span>
207
+ <span class="n">device</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">device</span>
208
+ <span class="n">threshold</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">thresh</span>
209
+
210
+ <span class="c1"># Set Outpath</span>
211
+ <span class="k">if</span> <span class="n">outPath</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
212
+ <span class="n">outPath</span> <span class="o">=</span> <span class="n">datetime</span><span class="o">.</span><span class="n">datetime</span><span class="o">.</span><span class="n">now</span><span class="p">()</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s2">&quot;%Y-%m-</span><span class="si">%d</span><span class="s2">-%H-%M.predictions&quot;</span><span class="p">)</span>
213
+
214
+ <span class="c1"># Set Device</span>
215
+ <span class="n">use_cuda</span> <span class="o">=</span> <span class="p">(</span><span class="n">device</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span>
216
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
217
+ <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
218
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Using CUDA device </span><span class="si">{device}</span><span class="s2"> - {torch.cuda.get_device_name(device)}&quot;</span><span class="p">)</span>
219
+ <span class="k">else</span><span class="p">:</span>
220
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Using CPU&quot;</span><span class="p">)</span>
221
+
222
+ <span class="c1"># Load Model</span>
223
+ <span class="k">try</span><span class="p">:</span>
224
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
225
+ <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">modelPath</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
226
+ <span class="k">else</span><span class="p">:</span>
227
+ <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">modelPath</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
228
+ <span class="n">model</span><span class="o">.</span><span class="n">use_cuda</span> <span class="o">=</span> <span class="kc">False</span>
229
+ <span class="k">except</span> <span class="ne">FileNotFoundError</span><span class="p">:</span>
230
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Model </span><span class="si">{modelPath}</span><span class="s2"> not found&quot;</span><span class="p">)</span>
231
+ <span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
232
+
233
+ <span class="c1"># Load Pairs</span>
234
+ <span class="k">try</span><span class="p">:</span>
235
+ <span class="n">pairs</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">csvPath</span><span class="p">,</span> <span class="n">sep</span><span class="o">=</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
236
+ <span class="n">all_prots</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">pairs</span><span class="o">.</span><span class="n">iloc</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">pairs</span><span class="o">.</span><span class="n">iloc</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]))</span>
237
+ <span class="k">except</span> <span class="ne">FileNotFoundError</span><span class="p">:</span>
238
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Pairs File </span><span class="si">{csvPath}</span><span class="s2"> not found&quot;</span><span class="p">)</span>
239
+ <span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
240
+
241
+ <span class="c1"># Load Sequences or Embeddings</span>
242
+ <span class="k">if</span> <span class="n">embPath</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
243
+ <span class="k">try</span><span class="p">:</span>
244
+ <span class="n">names</span><span class="p">,</span> <span class="n">seqs</span> <span class="o">=</span> <span class="n">parse</span><span class="p">(</span><span class="nb">open</span><span class="p">(</span><span class="n">seqPath</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">))</span>
245
+ <span class="n">seqDict</span> <span class="o">=</span> <span class="p">{</span><span class="n">n</span><span class="p">:</span> <span class="n">s</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">names</span><span class="p">,</span> <span class="n">seqs</span><span class="p">)}</span>
246
+ <span class="k">except</span> <span class="ne">FileNotFoundError</span><span class="p">:</span>
247
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Sequence File </span><span class="si">{fastaPath}</span><span class="s2"> not found&quot;</span><span class="p">)</span>
248
+ <span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
249
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Generating Embeddings...&quot;</span><span class="p">)</span>
250
+ <span class="n">embeddings</span> <span class="o">=</span> <span class="p">{}</span>
251
+ <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">all_prots</span><span class="p">):</span>
252
+ <span class="n">embeddings</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">lm_embed</span><span class="p">(</span><span class="n">seqDict</span><span class="p">[</span><span class="n">n</span><span class="p">],</span> <span class="n">use_cuda</span><span class="p">)</span>
253
+ <span class="k">else</span><span class="p">:</span>
254
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Loading Embeddings...&quot;</span><span class="p">)</span>
255
+ <span class="n">embedH5</span> <span class="o">=</span> <span class="n">h5py</span><span class="o">.</span><span class="n">File</span><span class="p">(</span><span class="n">embPath</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span>
256
+ <span class="n">embeddings</span> <span class="o">=</span> <span class="p">{}</span>
257
+ <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">all_prots</span><span class="p">):</span>
258
+ <span class="n">embeddings</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">embedH5</span><span class="p">[</span><span class="n">n</span><span class="p">][:])</span>
259
+ <span class="n">embedH5</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
260
+
261
+ <span class="c1"># Make Predictions</span>
262
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Making Predictions...&quot;</span><span class="p">)</span>
263
+ <span class="n">n</span> <span class="o">=</span> <span class="mi">0</span>
264
+ <span class="n">outPathAll</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{outPath}</span><span class="s2">.tsv&quot;</span>
265
+ <span class="n">outPathPos</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{outPath}</span><span class="s2">.positive.tsv&quot;</span>
266
+ <span class="n">cmap_file</span> <span class="o">=</span> <span class="n">h5py</span><span class="o">.</span><span class="n">File</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{outPath}</span><span class="s2">.cmaps.h5&quot;</span><span class="p">,</span> <span class="s2">&quot;w&quot;</span><span class="p">)</span>
267
+ <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">outPathAll</span><span class="p">,</span> <span class="s2">&quot;w+&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
268
+ <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">outPathPos</span><span class="p">,</span> <span class="s2">&quot;w+&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">pos_f</span><span class="p">:</span>
269
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
270
+ <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="p">(</span><span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">pairs</span><span class="o">.</span><span class="n">iloc</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">iterrows</span><span class="p">(),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">pairs</span><span class="p">)):</span>
271
+ <span class="n">n0</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">n0</span><span class="p">)</span>
272
+ <span class="n">n1</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">n1</span><span class="p">)</span>
273
+ <span class="k">if</span> <span class="n">n</span> <span class="o">%</span> <span class="mi">50</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
274
+ <span class="n">f</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
275
+ <span class="n">n</span> <span class="o">+=</span> <span class="mi">1</span>
276
+ <span class="n">p0</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[</span><span class="n">n0</span><span class="p">]</span>
277
+ <span class="n">p1</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[</span><span class="n">n1</span><span class="p">]</span>
278
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
279
+ <span class="n">p0</span> <span class="o">=</span> <span class="n">p0</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
280
+ <span class="n">p1</span> <span class="o">=</span> <span class="n">p1</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
281
+
282
+ <span class="n">cm</span><span class="p">,</span> <span class="n">p</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">map_predict</span><span class="p">(</span><span class="n">p0</span><span class="p">,</span> <span class="n">p1</span><span class="p">)</span>
283
+ <span class="n">p</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
284
+ <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{n0}</span><span class="se">\t</span><span class="si">{n1}</span><span class="se">\t</span><span class="si">{p}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
285
+ <span class="k">if</span> <span class="n">p</span> <span class="o">&gt;=</span> <span class="n">threshold</span><span class="p">:</span>
286
+ <span class="n">pos_f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{n0}</span><span class="se">\t</span><span class="si">{n1}</span><span class="se">\t</span><span class="si">{p}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
287
+ <span class="n">cmap_file</span><span class="o">.</span><span class="n">create_dataset</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{n0}</span><span class="s2">x</span><span class="si">{n1}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">cm</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
288
+
289
+ <span class="n">cmap_file</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
290
+
291
+
292
+ <span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
293
+ <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="vm">__doc__</span><span class="p">)</span>
294
+ <span class="n">add_args</span><span class="p">(</span><span class="n">parser</span><span class="p">)</span>
295
+ <span class="n">main</span><span class="p">(</span><span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">())</span>
296
+ </pre></div>
297
+
298
+ </div>
299
+
300
+ </div>
301
+ <footer>
302
+
303
+
304
+ <hr/>
305
+
306
+ <div role="contentinfo">
307
+ <p>
308
+
309
+ &copy; Copyright 2020, Samuel Sledzieski, Rohit Singh
310
+
311
+ </p>
312
+ </div>
313
+
314
+
315
+
316
+ Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a
317
+
318
+ <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a>
319
+
320
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
321
+
322
+ </footer>
323
+
324
+ </div>
325
+ </div>
326
+
327
+ </section>
328
+
329
+ </div>
330
+
331
+
332
+ <script type="text/javascript">
333
+ jQuery(function () {
334
+ SphinxRtdTheme.Navigation.enable(true);
335
+ });
336
+ </script>
337
+
338
+
339
+
340
+
341
+
342
+
343
+ </body>
344
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/commands/train.html ADDED
@@ -0,0 +1,755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+
10
+ <title>dscript.commands.train &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+ <!--[if lt IE 9]>
27
+ <script src="../../../_static/js/html5shiv.min.js"></script>
28
+ <![endif]-->
29
+
30
+
31
+ <script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
32
+ <script src="../../../_static/jquery.js"></script>
33
+ <script src="../../../_static/underscore.js"></script>
34
+ <script src="../../../_static/doctools.js"></script>
35
+
36
+ <script type="text/javascript" src="../../../_static/js/theme.js"></script>
37
+
38
+
39
+ <link rel="index" title="Index" href="../../../genindex.html" />
40
+ <link rel="search" title="Search" href="../../../search.html" />
41
+ </head>
42
+
43
+ <body class="wy-body-for-nav">
44
+
45
+
46
+ <div class="wy-grid-for-nav">
47
+
48
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
49
+ <div class="wy-side-scroll">
50
+ <div class="wy-side-nav-search" >
51
+
52
+
53
+
54
+ <a href="../../../index.html" class="icon icon-home"> D-SCRIPT
55
+
56
+
57
+
58
+ </a>
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+ <div role="search">
67
+ <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
68
+ <input type="text" name="q" placeholder="Search docs" />
69
+ <input type="hidden" name="check_keywords" value="yes" />
70
+ <input type="hidden" name="area" value="default" />
71
+ </form>
72
+ </div>
73
+
74
+
75
+ </div>
76
+
77
+
78
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
79
+
80
+
81
+
82
+
83
+
84
+
85
+ <ul>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Installation</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../../data.html">Data</a></li>
89
+ <li class="toctree-l1"><a class="reference internal" href="../../../api/index.html">API</a></li>
90
+ </ul>
91
+
92
+
93
+
94
+ </div>
95
+
96
+ </div>
97
+ </nav>
98
+
99
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
100
+
101
+
102
+ <nav class="wy-nav-top" aria-label="top navigation">
103
+
104
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
105
+ <a href="../../../index.html">D-SCRIPT</a>
106
+
107
+ </nav>
108
+
109
+
110
+ <div class="wy-nav-content">
111
+
112
+ <div class="rst-content">
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ <div role="navigation" aria-label="breadcrumbs navigation">
133
+
134
+ <ul class="wy-breadcrumbs">
135
+
136
+ <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
137
+
138
+ <li><a href="../../index.html">Module code</a> &raquo;</li>
139
+
140
+ <li>dscript.commands.train</li>
141
+
142
+
143
+ <li class="wy-breadcrumbs-aside">
144
+
145
+ </li>
146
+
147
+ </ul>
148
+
149
+
150
+ <hr/>
151
+ </div>
152
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
153
+ <div itemprop="articleBody">
154
+
155
+ <h1>Source code for dscript.commands.train</h1><div class="highlight"><pre>
156
+ <span></span><span class="sd">&quot;&quot;&quot;</span>
157
+ <span class="sd">Train a new model.</span>
158
+ <span class="sd">&quot;&quot;&quot;</span>
159
+
160
+ <span class="kn">import</span> <span class="nn">sys</span>
161
+ <span class="kn">import</span> <span class="nn">argparse</span>
162
+ <span class="kn">import</span> <span class="nn">h5py</span>
163
+ <span class="kn">import</span> <span class="nn">datetime</span>
164
+ <span class="kn">import</span> <span class="nn">subprocess</span> <span class="k">as</span> <span class="nn">sp</span>
165
+ <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
166
+ <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
167
+ <span class="kn">import</span> <span class="nn">gzip</span> <span class="k">as</span> <span class="nn">gz</span>
168
+ <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
169
+
170
+ <span class="kn">import</span> <span class="nn">torch</span>
171
+ <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
172
+ <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
173
+ <span class="kn">import</span> <span class="nn">torch.optim</span> <span class="k">as</span> <span class="nn">optim</span>
174
+ <span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
175
+ <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">IterableDataset</span><span class="p">,</span> <span class="n">DataLoader</span>
176
+ <span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">average_precision_score</span> <span class="k">as</span> <span class="n">average_precision</span>
177
+
178
+ <span class="kn">import</span> <span class="nn">dscript</span>
179
+ <span class="kn">from</span> <span class="nn">dscript.utils</span> <span class="kn">import</span> <span class="n">PairedDataset</span><span class="p">,</span> <span class="n">collate_paired_sequences</span>
180
+ <span class="kn">from</span> <span class="nn">dscript.models.embedding</span> <span class="kn">import</span> <span class="p">(</span>
181
+ <span class="n">IdentityEmbed</span><span class="p">,</span>
182
+ <span class="n">FullyConnectedEmbed</span><span class="p">,</span>
183
+ <span class="p">)</span>
184
+ <span class="kn">from</span> <span class="nn">dscript.models.contact</span> <span class="kn">import</span> <span class="n">ContactCNN</span>
185
+ <span class="kn">from</span> <span class="nn">dscript.models.interaction</span> <span class="kn">import</span> <span class="n">ModelInteraction</span>
186
+
187
+
188
+ <span class="k">def</span> <span class="nf">add_args</span><span class="p">(</span><span class="n">parser</span><span class="p">):</span>
189
+ <span class="sd">&quot;&quot;&quot;</span>
190
+ <span class="sd"> Create parser for command line utility.</span>
191
+
192
+ <span class="sd"> :meta private:</span>
193
+ <span class="sd"> &quot;&quot;&quot;</span>
194
+
195
+ <span class="n">data_grp</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument_group</span><span class="p">(</span><span class="s2">&quot;Data&quot;</span><span class="p">)</span>
196
+ <span class="n">proj_grp</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument_group</span><span class="p">(</span><span class="s2">&quot;Projection Module&quot;</span><span class="p">)</span>
197
+ <span class="n">contact_grp</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument_group</span><span class="p">(</span><span class="s2">&quot;Contact Module&quot;</span><span class="p">)</span>
198
+ <span class="n">inter_grp</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument_group</span><span class="p">(</span><span class="s2">&quot;Interaction Module&quot;</span><span class="p">)</span>
199
+ <span class="n">train_grp</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument_group</span><span class="p">(</span><span class="s2">&quot;Training&quot;</span><span class="p">)</span>
200
+ <span class="n">misc_grp</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">add_argument_group</span><span class="p">(</span><span class="s2">&quot;Output and Device&quot;</span><span class="p">)</span>
201
+
202
+ <span class="c1"># Data</span>
203
+ <span class="n">data_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--train&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Training data&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
204
+ <span class="n">data_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--val&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Validation data&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
205
+ <span class="n">data_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--embedding&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;h5 file with embedded sequences&quot;</span><span class="p">,</span> <span class="n">required</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
206
+ <span class="n">data_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
207
+ <span class="s2">&quot;--augment&quot;</span><span class="p">,</span>
208
+ <span class="n">action</span><span class="o">=</span><span class="s2">&quot;store_true&quot;</span><span class="p">,</span>
209
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Set flag to augment data by adding (B A) for all pairs (A B)&quot;</span><span class="p">,</span>
210
+ <span class="p">)</span>
211
+
212
+ <span class="c1"># Embedding model</span>
213
+ <span class="n">proj_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
214
+ <span class="s2">&quot;--projection-dim&quot;</span><span class="p">,</span>
215
+ <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span>
216
+ <span class="n">default</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
217
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Dimension of embedding projection layer (default: 100)&quot;</span><span class="p">,</span>
218
+ <span class="p">)</span>
219
+ <span class="n">proj_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
220
+ <span class="s2">&quot;--dropout-p&quot;</span><span class="p">,</span>
221
+ <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span>
222
+ <span class="n">default</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
223
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Parameter p for embedding dropout layer (default: 0.5)&quot;</span><span class="p">,</span>
224
+ <span class="p">)</span>
225
+
226
+ <span class="c1"># Contact model</span>
227
+ <span class="n">contact_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
228
+ <span class="s2">&quot;--hidden-dim&quot;</span><span class="p">,</span>
229
+ <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span>
230
+ <span class="n">default</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span>
231
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Number of hidden units for comparison layer in contact prediction (default: 50)&quot;</span><span class="p">,</span>
232
+ <span class="p">)</span>
233
+ <span class="n">contact_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
234
+ <span class="s2">&quot;--kernel-width&quot;</span><span class="p">,</span>
235
+ <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span>
236
+ <span class="n">default</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span>
237
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Width of convolutional filter for contact prediction (default: 7)&quot;</span><span class="p">,</span>
238
+ <span class="p">)</span>
239
+
240
+ <span class="c1"># Interaction Model</span>
241
+ <span class="n">inter_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
242
+ <span class="s2">&quot;--use-w&quot;</span><span class="p">,</span>
243
+ <span class="n">action</span><span class="o">=</span><span class="s2">&quot;store_true&quot;</span><span class="p">,</span>
244
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Use weight matrix in interaction prediction model&quot;</span><span class="p">,</span>
245
+ <span class="p">)</span>
246
+ <span class="n">inter_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
247
+ <span class="s2">&quot;--pool-width&quot;</span><span class="p">,</span>
248
+ <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span>
249
+ <span class="n">default</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span>
250
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Size of max-pool in interaction model (default: 9)&quot;</span><span class="p">,</span>
251
+ <span class="p">)</span>
252
+
253
+ <span class="c1"># Training</span>
254
+ <span class="n">train_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
255
+ <span class="s2">&quot;--negative-ratio&quot;</span><span class="p">,</span>
256
+ <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span>
257
+ <span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
258
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Number of negative training samples for each positive training sample (default: 10)&quot;</span><span class="p">,</span>
259
+ <span class="p">)</span>
260
+ <span class="n">train_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
261
+ <span class="s2">&quot;--epoch-scale&quot;</span><span class="p">,</span>
262
+ <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span>
263
+ <span class="n">default</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
264
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Report heldout performance every this many epochs (default: 5)&quot;</span><span class="p">,</span>
265
+ <span class="p">)</span>
266
+ <span class="n">train_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--num-epochs&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Number of epochs (default: 100)&quot;</span><span class="p">)</span>
267
+ <span class="n">train_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--batch-size&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">25</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Minibatch size (default: 25)&quot;</span><span class="p">)</span>
268
+ <span class="n">train_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--weight-decay&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;L2 regularization (default: 0)&quot;</span><span class="p">)</span>
269
+ <span class="n">train_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--lr&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="mf">0.001</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Learning rate (default: 0.001)&quot;</span><span class="p">)</span>
270
+ <span class="n">train_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span>
271
+ <span class="s2">&quot;--lambda&quot;</span><span class="p">,</span>
272
+ <span class="n">dest</span><span class="o">=</span><span class="s2">&quot;lambda_&quot;</span><span class="p">,</span>
273
+ <span class="nb">type</span><span class="o">=</span><span class="nb">float</span><span class="p">,</span>
274
+ <span class="n">default</span><span class="o">=</span><span class="mf">0.35</span><span class="p">,</span>
275
+ <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Weight on the similarity objective (default: 0.35)&quot;</span><span class="p">,</span>
276
+ <span class="p">)</span>
277
+
278
+ <span class="c1"># Output</span>
279
+ <span class="n">misc_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;-o&quot;</span><span class="p">,</span> <span class="s2">&quot;--outfile&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Output file path (default: stdout)&quot;</span><span class="p">)</span>
280
+ <span class="n">misc_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--save-prefix&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Path prefix for saving models&quot;</span><span class="p">)</span>
281
+ <span class="n">misc_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;-d&quot;</span><span class="p">,</span> <span class="s2">&quot;--device&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Compute device to use&quot;</span><span class="p">)</span>
282
+ <span class="n">misc_grp</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--checkpoint&quot;</span><span class="p">,</span> <span class="n">help</span><span class="o">=</span><span class="s2">&quot;Checkpoint model to start training from&quot;</span><span class="p">)</span>
283
+
284
+ <span class="k">return</span> <span class="n">parser</span>
285
+
286
+
287
+ <div class="viewcode-block" id="predict_interaction"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.train.predict_interaction">[docs]</a><span class="k">def</span> <span class="nf">predict_interaction</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">tensors</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">):</span>
288
+ <span class="sd">&quot;&quot;&quot;</span>
289
+ <span class="sd"> Predict whether a list of protein pairs will interact.</span>
290
+
291
+ <span class="sd"> :param model: Model to be trained</span>
292
+ <span class="sd"> :type model: dscript.models.interaction.ModelInteraction</span>
293
+ <span class="sd"> :param n0: First protein names</span>
294
+ <span class="sd"> :type n0: list[str]</span>
295
+ <span class="sd"> :param n1: Second protein names</span>
296
+ <span class="sd"> :type n1: list[str]</span>
297
+ <span class="sd"> :param tensors: Dictionary of protein names to embeddings</span>
298
+ <span class="sd"> :type tensors: dict[str, torch.Tensor]</span>
299
+ <span class="sd"> :param use_cuda: Whether to use GPU</span>
300
+ <span class="sd"> :type use_cuda: bool</span>
301
+ <span class="sd"> &quot;&quot;&quot;</span>
302
+
303
+ <span class="n">b</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">n0</span><span class="p">)</span>
304
+
305
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="p">[]</span>
306
+ <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">b</span><span class="p">):</span>
307
+ <span class="n">z_a</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">n0</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
308
+ <span class="n">z_b</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">n1</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
309
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
310
+ <span class="n">z_a</span> <span class="o">=</span> <span class="n">z_a</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
311
+ <span class="n">z_b</span> <span class="o">=</span> <span class="n">z_b</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
312
+
313
+ <span class="n">p_hat</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">z_a</span><span class="p">,</span> <span class="n">z_b</span><span class="p">))</span>
314
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">p_hat</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
315
+ <span class="k">return</span> <span class="n">p_hat</span></div>
316
+
317
+
318
+ <div class="viewcode-block" id="predict_cmap_interaction"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.train.predict_cmap_interaction">[docs]</a><span class="k">def</span> <span class="nf">predict_cmap_interaction</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">tensors</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">):</span>
319
+ <span class="sd">&quot;&quot;&quot;</span>
320
+ <span class="sd"> Predict whether a list of protein pairs will interact, as well as their contact map.</span>
321
+
322
+ <span class="sd"> :param model: Model to be trained</span>
323
+ <span class="sd"> :type model: dscript.models.interaction.ModelInteraction</span>
324
+ <span class="sd"> :param n0: First protein names</span>
325
+ <span class="sd"> :type n0: list[str]</span>
326
+ <span class="sd"> :param n1: Second protein names</span>
327
+ <span class="sd"> :type n1: list[str]</span>
328
+ <span class="sd"> :param tensors: Dictionary of protein names to embeddings</span>
329
+ <span class="sd"> :type tensors: dict[str, torch.Tensor]</span>
330
+ <span class="sd"> :param use_cuda: Whether to use GPU</span>
331
+ <span class="sd"> :type use_cuda: bool</span>
332
+ <span class="sd"> &quot;&quot;&quot;</span>
333
+
334
+ <span class="n">b</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">n0</span><span class="p">)</span>
335
+
336
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="p">[]</span>
337
+ <span class="n">c_map_mag</span> <span class="o">=</span> <span class="p">[]</span>
338
+ <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">b</span><span class="p">):</span>
339
+ <span class="n">z_a</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">n0</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
340
+ <span class="n">z_b</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">n1</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
341
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
342
+ <span class="n">z_a</span> <span class="o">=</span> <span class="n">z_a</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
343
+ <span class="n">z_b</span> <span class="o">=</span> <span class="n">z_b</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
344
+
345
+ <span class="n">cm</span><span class="p">,</span> <span class="n">ph</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">map_predict</span><span class="p">(</span><span class="n">z_a</span><span class="p">,</span> <span class="n">z_b</span><span class="p">)</span>
346
+ <span class="n">p_hat</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ph</span><span class="p">)</span>
347
+ <span class="n">c_map_mag</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">cm</span><span class="p">))</span>
348
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">p_hat</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
349
+ <span class="n">c_map_mag</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">c_map_mag</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
350
+ <span class="k">return</span> <span class="n">c_map_mag</span><span class="p">,</span> <span class="n">p_hat</span></div>
351
+
352
+
353
+ <div class="viewcode-block" id="interaction_grad"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.train.interaction_grad">[docs]</a><span class="k">def</span> <span class="nf">interaction_grad</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">tensors</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="mf">0.35</span><span class="p">):</span>
354
+ <span class="sd">&quot;&quot;&quot;</span>
355
+ <span class="sd"> Compute gradient and backpropagate loss for a batch.</span>
356
+
357
+ <span class="sd"> :param model: Model to be trained</span>
358
+ <span class="sd"> :type model: dscript.models.interaction.ModelInteraction</span>
359
+ <span class="sd"> :param n0: First protein names</span>
360
+ <span class="sd"> :type n0: list[str]</span>
361
+ <span class="sd"> :param n1: Second protein names</span>
362
+ <span class="sd"> :type n1: list[str]</span>
363
+ <span class="sd"> :param y: Interaction labels</span>
364
+ <span class="sd"> :type y: torch.Tensor</span>
365
+ <span class="sd"> :param tensors: Dictionary of protein names to embeddings</span>
366
+ <span class="sd"> :type tensors: dict[str, torch.Tensor]</span>
367
+ <span class="sd"> :param use_cuda: Whether to use GPU</span>
368
+ <span class="sd"> :type use_cuda: bool</span>
369
+ <span class="sd"> :param weight: Weight on the contact map magnitude objective. BCE loss is :math:`1 - \\text{weight}`.</span>
370
+ <span class="sd"> :type weight: float</span>
371
+
372
+ <span class="sd"> :return: (Loss, number correct, mean square error, batch size)</span>
373
+ <span class="sd"> :rtype: (torch.Tensor, int, torch.Tensor, int)</span>
374
+ <span class="sd"> &quot;&quot;&quot;</span>
375
+
376
+ <span class="n">c_map_mag</span><span class="p">,</span> <span class="n">p_hat</span> <span class="o">=</span> <span class="n">predict_cmap_interaction</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">tensors</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">)</span>
377
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
378
+ <span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
379
+ <span class="n">y</span> <span class="o">=</span> <span class="n">Variable</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
380
+
381
+ <span class="n">bce_loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">p_hat</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">float</span><span class="p">())</span>
382
+ <span class="n">cmap_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">c_map_mag</span><span class="p">)</span>
383
+ <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span> <span class="o">*</span> <span class="n">bce_loss</span><span class="p">)</span> <span class="o">+</span> <span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">weight</span><span class="p">)</span> <span class="o">*</span> <span class="n">cmap_loss</span><span class="p">)</span>
384
+ <span class="n">b</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">p_hat</span><span class="p">)</span>
385
+
386
+ <span class="c1"># backprop loss</span>
387
+ <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
388
+
389
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
390
+ <span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
391
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="n">p_hat</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
392
+
393
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
394
+ <span class="n">guess_cutoff</span> <span class="o">=</span> <span class="mf">0.5</span>
395
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="n">p_hat</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
396
+ <span class="n">p_guess</span> <span class="o">=</span> <span class="p">(</span><span class="n">guess_cutoff</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">b</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">p_hat</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
397
+ <span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
398
+ <span class="n">correct</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p_guess</span> <span class="o">==</span> <span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
399
+ <span class="n">mse</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">((</span><span class="n">y</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">-</span> <span class="n">p_hat</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
400
+
401
+ <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="n">mse</span><span class="p">,</span> <span class="n">b</span></div>
402
+
403
+
404
+ <div class="viewcode-block" id="interaction_eval"><a class="viewcode-back" href="../../../api/dscript.commands.html#dscript.commands.train.interaction_eval">[docs]</a><span class="k">def</span> <span class="nf">interaction_eval</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">test_iterator</span><span class="p">,</span> <span class="n">tensors</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">):</span>
405
+ <span class="sd">&quot;&quot;&quot;</span>
406
+ <span class="sd"> Evaluate test data set performance.</span>
407
+
408
+ <span class="sd"> :param model: Model to be trained</span>
409
+ <span class="sd"> :type model: dscript.models.interaction.ModelInteraction</span>
410
+ <span class="sd"> :param test_iterator: Test data iterator</span>
411
+ <span class="sd"> :type test_iterator: torch.utils.data.DataLoader</span>
412
+ <span class="sd"> :param tensors: Dictionary of protein names to embeddings</span>
413
+ <span class="sd"> :type tensors: dict[str, torch.Tensor]</span>
414
+ <span class="sd"> :param use_cuda: Whether to use GPU</span>
415
+ <span class="sd"> :type use_cuda: bool</span>
416
+
417
+ <span class="sd"> :return: (Loss, number correct, mean square error, precision, recall, F1 Score, AUPR)</span>
418
+ <span class="sd"> :rtype: (torch.Tensor, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)</span>
419
+ <span class="sd"> &quot;&quot;&quot;</span>
420
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="p">[]</span>
421
+ <span class="n">true_y</span> <span class="o">=</span> <span class="p">[]</span>
422
+
423
+ <span class="k">for</span> <span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">test_iterator</span><span class="p">:</span>
424
+ <span class="n">p_hat</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">predict_interaction</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">n0</span><span class="p">,</span> <span class="n">n1</span><span class="p">,</span> <span class="n">tensors</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">))</span>
425
+ <span class="n">true_y</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
426
+
427
+ <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">true_y</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
428
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">p_hat</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
429
+
430
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
431
+ <span class="n">y</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
432
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">p_hat</span><span class="p">])</span>
433
+ <span class="n">p_hat</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
434
+
435
+ <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">p_hat</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">y</span><span class="o">.</span><span class="n">float</span><span class="p">())</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
436
+ <span class="n">b</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
437
+
438
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
439
+ <span class="n">guess_cutoff</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="mf">0.5</span><span class="p">])</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
440
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="n">p_hat</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
441
+ <span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
442
+ <span class="n">p_guess</span> <span class="o">=</span> <span class="p">(</span><span class="n">guess_cutoff</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">b</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">p_hat</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
443
+ <span class="n">correct</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p_guess</span> <span class="o">==</span> <span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
444
+ <span class="n">mse</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">((</span><span class="n">y</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">-</span> <span class="n">p_hat</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
445
+
446
+ <span class="n">tp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">y</span> <span class="o">*</span> <span class="n">p_hat</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
447
+ <span class="n">pr</span> <span class="o">=</span> <span class="n">tp</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">p_hat</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
448
+ <span class="n">re</span> <span class="o">=</span> <span class="n">tp</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">y</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
449
+ <span class="n">f1</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">pr</span> <span class="o">*</span> <span class="n">re</span> <span class="o">/</span> <span class="p">(</span><span class="n">pr</span> <span class="o">+</span> <span class="n">re</span><span class="p">)</span>
450
+
451
+ <span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
452
+ <span class="n">p_hat</span> <span class="o">=</span> <span class="n">p_hat</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
453
+
454
+ <span class="n">aupr</span> <span class="o">=</span> <span class="n">average_precision</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">p_hat</span><span class="p">)</span>
455
+
456
+ <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="n">mse</span><span class="p">,</span> <span class="n">pr</span><span class="p">,</span> <span class="n">re</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">aupr</span></div>
457
+
458
+
459
+ <span class="k">def</span> <span class="nf">main</span><span class="p">(</span><span class="n">args</span><span class="p">):</span>
460
+ <span class="sd">&quot;&quot;&quot;</span>
461
+ <span class="sd"> Run training from arguments.</span>
462
+
463
+ <span class="sd"> :meta private:</span>
464
+ <span class="sd"> &quot;&quot;&quot;</span>
465
+
466
+ <span class="n">output</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">outfile</span>
467
+ <span class="k">if</span> <span class="n">output</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
468
+ <span class="n">output</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span>
469
+ <span class="k">else</span><span class="p">:</span>
470
+ <span class="n">output</span> <span class="o">=</span> <span class="nb">open</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="s2">&quot;w&quot;</span><span class="p">)</span>
471
+
472
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;# Called as: </span><span class="si">{</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sys</span><span class="o">.</span><span class="n">argv</span><span class="p">)</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
473
+ <span class="k">if</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="p">:</span>
474
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Called as: </span><span class="si">{</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sys</span><span class="o">.</span><span class="n">argv</span><span class="p">)</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
475
+
476
+ <span class="c1"># Set device</span>
477
+ <span class="n">device</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">device</span>
478
+ <span class="n">use_cuda</span> <span class="o">=</span> <span class="p">(</span><span class="n">device</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span>
479
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
480
+ <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
481
+ <span class="nb">print</span><span class="p">(</span>
482
+ <span class="sa">f</span><span class="s2">&quot;# Using CUDA device </span><span class="si">{</span><span class="n">device</span><span class="si">}</span><span class="s2"> - </span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">get_device_name</span><span class="p">(</span><span class="n">device</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
483
+ <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">,</span>
484
+ <span class="p">)</span>
485
+ <span class="k">else</span><span class="p">:</span>
486
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Using CPU&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
487
+ <span class="n">device</span> <span class="o">=</span> <span class="s2">&quot;cpu&quot;</span>
488
+
489
+ <span class="n">batch_size</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">batch_size</span>
490
+
491
+ <span class="n">train_fi</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">train</span>
492
+ <span class="n">test_fi</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">val</span>
493
+ <span class="n">augment</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">augment</span>
494
+ <span class="n">embedding_h5</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">embedding</span>
495
+ <span class="n">h5fi</span> <span class="o">=</span> <span class="n">h5py</span><span class="o">.</span><span class="n">File</span><span class="p">(</span><span class="n">embedding_h5</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span>
496
+
497
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Loading training pairs from </span><span class="si">{</span><span class="n">train_fi</span><span class="si">}</span><span class="s2">...&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
498
+ <span class="n">output</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
499
+
500
+ <span class="n">train_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">train_fi</span><span class="p">,</span> <span class="n">sep</span><span class="o">=</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
501
+ <span class="k">if</span> <span class="n">augment</span><span class="p">:</span>
502
+ <span class="n">train_n0</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">train_df</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">train_df</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
503
+ <span class="n">train_n1</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">train_df</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">train_df</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
504
+ <span class="n">train_y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">pd</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">train_df</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">train_df</span><span class="p">[</span><span class="mi">2</span><span class="p">]))</span><span class="o">.</span><span class="n">values</span><span class="p">)</span>
505
+ <span class="k">else</span><span class="p">:</span>
506
+ <span class="n">train_n0</span><span class="p">,</span> <span class="n">train_n1</span> <span class="o">=</span> <span class="n">train_df</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">train_df</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
507
+ <span class="n">train_y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">train_df</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">)</span>
508
+
509
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Loading testing pairs from </span><span class="si">{</span><span class="n">test_fi</span><span class="si">}</span><span class="s2">...&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
510
+ <span class="n">output</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
511
+
512
+ <span class="n">test_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">test_fi</span><span class="p">,</span> <span class="n">sep</span><span class="o">=</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
513
+ <span class="k">if</span> <span class="n">augment</span><span class="p">:</span>
514
+ <span class="n">test_n0</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">test_df</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">test_df</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
515
+ <span class="n">test_n1</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">test_df</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">test_df</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
516
+ <span class="n">test_y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">pd</span><span class="o">.</span><span class="n">concat</span><span class="p">((</span><span class="n">test_df</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">test_df</span><span class="p">[</span><span class="mi">2</span><span class="p">]))</span><span class="o">.</span><span class="n">values</span><span class="p">)</span>
517
+ <span class="k">else</span><span class="p">:</span>
518
+ <span class="n">test_n0</span><span class="p">,</span> <span class="n">test_n1</span> <span class="o">=</span> <span class="n">test_df</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">test_df</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
519
+ <span class="n">test_y</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">test_df</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">)</span>
520
+ <span class="n">output</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
521
+
522
+ <span class="n">train_pairs</span> <span class="o">=</span> <span class="n">PairedDataset</span><span class="p">(</span><span class="n">train_n0</span><span class="p">,</span> <span class="n">train_n1</span><span class="p">,</span> <span class="n">train_y</span><span class="p">)</span>
523
+ <span class="n">pairs_train_iterator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
524
+ <span class="n">train_pairs</span><span class="p">,</span>
525
+ <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
526
+ <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_paired_sequences</span><span class="p">,</span>
527
+ <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
528
+ <span class="p">)</span>
529
+
530
+ <span class="n">test_pairs</span> <span class="o">=</span> <span class="n">PairedDataset</span><span class="p">(</span><span class="n">test_n0</span><span class="p">,</span> <span class="n">test_n1</span><span class="p">,</span> <span class="n">test_y</span><span class="p">)</span>
531
+ <span class="n">pairs_test_iterator</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
532
+ <span class="n">test_pairs</span><span class="p">,</span>
533
+ <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
534
+ <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_paired_sequences</span><span class="p">,</span>
535
+ <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
536
+ <span class="p">)</span>
537
+
538
+ <span class="n">output</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
539
+
540
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Loading embeddings&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
541
+ <span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span>
542
+ <span class="n">all_proteins</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">train_n0</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">train_n1</span><span class="p">))</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">test_n0</span><span class="p">))</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">test_n1</span><span class="p">))</span>
543
+ <span class="k">for</span> <span class="n">prot_name</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">all_proteins</span><span class="p">):</span>
544
+ <span class="n">tensors</span><span class="p">[</span><span class="n">prot_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">h5fi</span><span class="p">[</span><span class="n">prot_name</span><span class="p">][:,</span> <span class="p">:])</span>
545
+
546
+ <span class="n">use_cuda</span> <span class="o">=</span> <span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">device</span> <span class="o">&gt;</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span>
547
+
548
+ <span class="k">if</span> <span class="n">args</span><span class="o">.</span><span class="n">checkpoint</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
549
+
550
+ <span class="n">projection_dim</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">projection_dim</span>
551
+ <span class="n">dropout_p</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">dropout_p</span>
552
+ <span class="n">embedding</span> <span class="o">=</span> <span class="n">FullyConnectedEmbed</span><span class="p">(</span><span class="mi">6165</span><span class="p">,</span> <span class="n">projection_dim</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">dropout_p</span><span class="p">)</span>
553
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Initializing embedding model with:&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
554
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">projection_dim: </span><span class="si">{</span><span class="n">projection_dim</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
555
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">dropout_p: </span><span class="si">{</span><span class="n">dropout_p</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
556
+
557
+ <span class="c1"># Create contact model</span>
558
+ <span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">hidden_dim</span>
559
+ <span class="n">kernel_width</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">kernel_width</span>
560
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Initializing contact model with:&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
561
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">hidden_dim: </span><span class="si">{</span><span class="n">hidden_dim</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
562
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">kernel_width: </span><span class="si">{</span><span class="n">kernel_width</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
563
+
564
+ <span class="n">contact</span> <span class="o">=</span> <span class="n">ContactCNN</span><span class="p">(</span><span class="n">projection_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">kernel_width</span><span class="p">)</span>
565
+
566
+ <span class="c1"># Create the full model</span>
567
+ <span class="n">use_W</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">use_w</span>
568
+ <span class="n">pool_width</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">pool_width</span>
569
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Initializing interaction model with:&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
570
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">pool_width: </span><span class="si">{</span><span class="n">pool_width</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
571
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">use_w: </span><span class="si">{</span><span class="n">use_W</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
572
+ <span class="n">model</span> <span class="o">=</span> <span class="n">ModelInteraction</span><span class="p">(</span><span class="n">embedding</span><span class="p">,</span> <span class="n">contact</span><span class="p">,</span> <span class="n">use_W</span><span class="o">=</span><span class="n">use_W</span><span class="p">,</span> <span class="n">pool_size</span><span class="o">=</span><span class="n">pool_width</span><span class="p">)</span>
573
+
574
+ <span class="nb">print</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
575
+
576
+ <span class="k">else</span><span class="p">:</span>
577
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Loading model from checkpoint </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">checkpoint</span><span class="p">),</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
578
+ <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">args</span><span class="o">.</span><span class="n">checkpoint</span><span class="p">)</span>
579
+ <span class="n">model</span><span class="o">.</span><span class="n">use_cuda</span> <span class="o">=</span> <span class="n">use_cuda</span>
580
+
581
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
582
+ <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
583
+
584
+ <span class="c1"># Train the model</span>
585
+ <span class="n">lr</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">lr</span>
586
+ <span class="n">wd</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">weight_decay</span>
587
+ <span class="n">num_epochs</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">num_epochs</span>
588
+ <span class="n">batch_size</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">batch_size</span>
589
+ <span class="n">report_steps</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">epoch_scale</span>
590
+ <span class="n">inter_weight</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">lambda_</span>
591
+ <span class="n">cmap_weight</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">inter_weight</span>
592
+ <span class="n">digits</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log10</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">)))</span> <span class="o">+</span> <span class="mi">1</span>
593
+ <span class="n">save_prefix</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">save_prefix</span>
594
+ <span class="k">if</span> <span class="n">save_prefix</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
595
+ <span class="n">save_prefix</span> <span class="o">=</span> <span class="n">datetime</span><span class="o">.</span><span class="n">datetime</span><span class="o">.</span><span class="n">now</span><span class="p">()</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s2">&quot;%Y-%m-</span><span class="si">%d</span><span class="s2">-%H-%M&quot;</span><span class="p">)</span>
596
+
597
+ <span class="n">params</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">()</span> <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">]</span>
598
+ <span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">wd</span><span class="p">)</span>
599
+
600
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;# Using save prefix &quot;</span><span class="si">{</span><span class="n">save_prefix</span><span class="si">}</span><span class="s1">&quot;&#39;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
601
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Training with Adam: lr=</span><span class="si">{</span><span class="n">lr</span><span class="si">}</span><span class="s2">, weight_decay=</span><span class="si">{</span><span class="n">wd</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
602
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">num_epochs: </span><span class="si">{</span><span class="n">num_epochs</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
603
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">epoch_scale: </span><span class="si">{</span><span class="n">report_steps</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
604
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">batch_size: </span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
605
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">interaction weight: </span><span class="si">{</span><span class="n">inter_weight</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
606
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\t</span><span class="s2">contact map weight: </span><span class="si">{</span><span class="n">cmap_weight</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
607
+ <span class="n">output</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
608
+
609
+ <span class="n">batch_report_fmt</span> <span class="o">=</span> <span class="s2">&quot;# [</span><span class="si">{}</span><span class="s2">/</span><span class="si">{}</span><span class="s2">] training </span><span class="si">{:.1%}</span><span class="s2">: Loss=</span><span class="si">{:.6}</span><span class="s2">, Accuracy=</span><span class="si">{:.3%}</span><span class="s2">, MSE=</span><span class="si">{:.6}</span><span class="s2">&quot;</span>
610
+ <span class="n">epoch_report_fmt</span> <span class="o">=</span> <span class="s2">&quot;# Finished Epoch </span><span class="si">{}</span><span class="s2">/</span><span class="si">{}</span><span class="s2">: Loss=</span><span class="si">{:.6}</span><span class="s2">, Accuracy=</span><span class="si">{:.3%}</span><span class="s2">, MSE=</span><span class="si">{:.6}</span><span class="s2">, Precision=</span><span class="si">{:.6}</span><span class="s2">, Recall=</span><span class="si">{:.6}</span><span class="s2">, F1=</span><span class="si">{:.6}</span><span class="s2">, AUPR=</span><span class="si">{:.6}</span><span class="s2">&quot;</span>
611
+
612
+ <span class="n">N</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">pairs_train_iterator</span><span class="p">)</span> <span class="o">*</span> <span class="n">batch_size</span>
613
+ <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">):</span>
614
+
615
+ <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
616
+
617
+ <span class="n">n</span> <span class="o">=</span> <span class="mi">0</span>
618
+ <span class="n">loss_accum</span> <span class="o">=</span> <span class="mi">0</span>
619
+ <span class="n">acc_accum</span> <span class="o">=</span> <span class="mi">0</span>
620
+ <span class="n">mse_accum</span> <span class="o">=</span> <span class="mi">0</span>
621
+
622
+ <span class="c1"># Train batches</span>
623
+ <span class="k">for</span> <span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">pairs_train_iterator</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span class="n">num_epochs</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span><span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">pairs_train_iterator</span><span class="p">)):</span>
624
+
625
+ <span class="n">loss</span><span class="p">,</span> <span class="n">correct</span><span class="p">,</span> <span class="n">mse</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">interaction_grad</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">tensors</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="n">inter_weight</span><span class="p">)</span>
626
+
627
+ <span class="n">n</span> <span class="o">+=</span> <span class="n">b</span>
628
+ <span class="n">delta</span> <span class="o">=</span> <span class="n">b</span> <span class="o">*</span> <span class="p">(</span><span class="n">loss</span> <span class="o">-</span> <span class="n">loss_accum</span><span class="p">)</span>
629
+ <span class="n">loss_accum</span> <span class="o">+=</span> <span class="n">delta</span> <span class="o">/</span> <span class="n">n</span>
630
+
631
+ <span class="n">delta</span> <span class="o">=</span> <span class="n">correct</span> <span class="o">-</span> <span class="n">b</span> <span class="o">*</span> <span class="n">acc_accum</span>
632
+ <span class="n">acc_accum</span> <span class="o">+=</span> <span class="n">delta</span> <span class="o">/</span> <span class="n">n</span>
633
+
634
+ <span class="n">delta</span> <span class="o">=</span> <span class="n">b</span> <span class="o">*</span> <span class="p">(</span><span class="n">mse</span> <span class="o">-</span> <span class="n">mse_accum</span><span class="p">)</span>
635
+ <span class="n">mse_accum</span> <span class="o">+=</span> <span class="n">delta</span> <span class="o">/</span> <span class="n">n</span>
636
+
637
+ <span class="n">report</span> <span class="o">=</span> <span class="p">(</span><span class="n">n</span> <span class="o">-</span> <span class="n">b</span><span class="p">)</span> <span class="o">//</span> <span class="mi">100</span> <span class="o">&lt;</span> <span class="n">n</span> <span class="o">//</span> <span class="mi">100</span>
638
+
639
+ <span class="n">optim</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
640
+ <span class="n">optim</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
641
+ <span class="n">model</span><span class="o">.</span><span class="n">clip</span><span class="p">()</span>
642
+
643
+ <span class="k">if</span> <span class="n">report</span><span class="p">:</span>
644
+ <span class="n">tokens</span> <span class="o">=</span> <span class="p">[</span>
645
+ <span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
646
+ <span class="n">num_epochs</span><span class="p">,</span>
647
+ <span class="n">n</span> <span class="o">/</span> <span class="n">N</span><span class="p">,</span>
648
+ <span class="n">loss_accum</span><span class="p">,</span>
649
+ <span class="n">acc_accum</span><span class="p">,</span>
650
+ <span class="n">mse_accum</span><span class="p">,</span>
651
+ <span class="p">]</span>
652
+ <span class="k">if</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span><span class="p">:</span>
653
+ <span class="nb">print</span><span class="p">(</span><span class="n">batch_report_fmt</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="o">*</span><span class="n">tokens</span><span class="p">),</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
654
+ <span class="n">output</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
655
+
656
+ <span class="k">if</span> <span class="p">(</span><span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">report_steps</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
657
+ <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
658
+
659
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
660
+
661
+ <span class="p">(</span>
662
+ <span class="n">inter_loss</span><span class="p">,</span>
663
+ <span class="n">inter_correct</span><span class="p">,</span>
664
+ <span class="n">inter_mse</span><span class="p">,</span>
665
+ <span class="n">inter_pr</span><span class="p">,</span>
666
+ <span class="n">inter_re</span><span class="p">,</span>
667
+ <span class="n">inter_f1</span><span class="p">,</span>
668
+ <span class="n">inter_aupr</span><span class="p">,</span>
669
+ <span class="p">)</span> <span class="o">=</span> <span class="n">interaction_eval</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">pairs_test_iterator</span><span class="p">,</span> <span class="n">tensors</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">)</span>
670
+ <span class="n">tokens</span> <span class="o">=</span> <span class="p">[</span>
671
+ <span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
672
+ <span class="n">num_epochs</span><span class="p">,</span>
673
+ <span class="n">inter_loss</span><span class="p">,</span>
674
+ <span class="n">inter_correct</span> <span class="o">/</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">pairs_test_iterator</span><span class="p">)</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">),</span>
675
+ <span class="n">inter_mse</span><span class="p">,</span>
676
+ <span class="n">inter_pr</span><span class="p">,</span>
677
+ <span class="n">inter_re</span><span class="p">,</span>
678
+ <span class="n">inter_f1</span><span class="p">,</span>
679
+ <span class="n">inter_aupr</span><span class="p">,</span>
680
+ <span class="p">]</span>
681
+ <span class="nb">print</span><span class="p">(</span><span class="n">epoch_report_fmt</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="o">*</span><span class="n">tokens</span><span class="p">),</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
682
+ <span class="n">output</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
683
+
684
+ <span class="c1"># Save the model</span>
685
+ <span class="k">if</span> <span class="n">save_prefix</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
686
+ <span class="n">save_path</span> <span class="o">=</span> <span class="n">save_prefix</span> <span class="o">+</span> <span class="s2">&quot;_epoch&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">zfill</span><span class="p">(</span><span class="n">digits</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;.sav&quot;</span>
687
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Saving model to </span><span class="si">{</span><span class="n">save_path</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
688
+ <span class="n">model</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
689
+ <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">save_path</span><span class="p">)</span>
690
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
691
+ <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
692
+
693
+ <span class="n">output</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
694
+
695
+ <span class="k">if</span> <span class="n">save_prefix</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
696
+ <span class="n">save_path</span> <span class="o">=</span> <span class="n">save_prefix</span> <span class="o">+</span> <span class="s2">&quot;_final.sav&quot;</span>
697
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Saving final model to </span><span class="si">{</span><span class="n">save_path</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">output</span><span class="p">)</span>
698
+ <span class="n">model</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
699
+ <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">save_path</span><span class="p">)</span>
700
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
701
+ <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
702
+
703
+ <span class="n">output</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
704
+
705
+
706
+ <span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
707
+ <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="vm">__doc__</span><span class="p">)</span>
708
+ <span class="n">add_args</span><span class="p">(</span><span class="n">parser</span><span class="p">)</span>
709
+ <span class="n">main</span><span class="p">(</span><span class="n">parser</span><span class="o">.</span><span class="n">parse_args</span><span class="p">())</span>
710
+ </pre></div>
711
+
712
+ </div>
713
+
714
+ </div>
715
+ <footer>
716
+
717
+ <hr/>
718
+
719
+ <div role="contentinfo">
720
+ <p>
721
+ &#169; Copyright 2020, Samuel Sledzieski, Rohit Singh.
722
+
723
+ </p>
724
+ </div>
725
+
726
+
727
+
728
+ Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
729
+
730
+ <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
731
+
732
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
733
+
734
+ </footer>
735
+ </div>
736
+ </div>
737
+
738
+ </section>
739
+
740
+ </div>
741
+
742
+
743
+ <script type="text/javascript">
744
+ jQuery(function () {
745
+ SphinxRtdTheme.Navigation.enable(true);
746
+ });
747
+ </script>
748
+
749
+
750
+
751
+
752
+
753
+
754
+ </body>
755
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/fasta.html ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8">
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
9
+
10
+ <title>dscript.fasta &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+ <!--[if lt IE 9]>
24
+ <script src="../../_static/js/html5shiv.min.js"></script>
25
+ <![endif]-->
26
+
27
+
28
+ <script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
29
+ <script src="../../_static/jquery.js"></script>
30
+ <script src="../../_static/underscore.js"></script>
31
+ <script src="../../_static/doctools.js"></script>
32
+ <script src="../../_static/language_data.js"></script>
33
+ <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
34
+
35
+ <script type="text/javascript" src="../../_static/js/theme.js"></script>
36
+
37
+
38
+ <link rel="index" title="Index" href="../../genindex.html" />
39
+ <link rel="search" title="Search" href="../../search.html" />
40
+ </head>
41
+
42
+ <body class="wy-body-for-nav">
43
+
44
+
45
+ <div class="wy-grid-for-nav">
46
+
47
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
48
+ <div class="wy-side-scroll">
49
+ <div class="wy-side-nav-search" >
50
+
51
+
52
+
53
+ <a href="../../index.html" class="icon icon-home" alt="Documentation Home"> D-SCRIPT
54
+
55
+
56
+
57
+ </a>
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+ <div role="search">
66
+ <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
67
+ <input type="text" name="q" placeholder="Search docs" />
68
+ <input type="hidden" name="check_keywords" value="yes" />
69
+ <input type="hidden" name="area" value="default" />
70
+ </form>
71
+ </div>
72
+
73
+
74
+ </div>
75
+
76
+
77
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
78
+
79
+
80
+
81
+
82
+
83
+
84
+ <ul>
85
+ <li class="toctree-l1"><a class="reference internal" href="../../installation.html">Installation</a></li>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../usage.html">Usage</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../data.html">Data</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../api/index.html">API</a></li>
89
+ </ul>
90
+
91
+
92
+
93
+ </div>
94
+
95
+ </div>
96
+ </nav>
97
+
98
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
99
+
100
+
101
+ <nav class="wy-nav-top" aria-label="top navigation">
102
+
103
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
104
+ <a href="../../index.html">D-SCRIPT</a>
105
+
106
+ </nav>
107
+
108
+
109
+ <div class="wy-nav-content">
110
+
111
+ <div class="rst-content">
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+ <div role="navigation" aria-label="breadcrumbs navigation">
130
+
131
+ <ul class="wy-breadcrumbs">
132
+
133
+ <li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
134
+
135
+ <li><a href="../index.html">Module code</a> &raquo;</li>
136
+
137
+ <li>dscript.fasta</li>
138
+
139
+
140
+ <li class="wy-breadcrumbs-aside">
141
+
142
+ </li>
143
+
144
+ </ul>
145
+
146
+
147
+ <hr/>
148
+ </div>
149
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
150
+ <div itemprop="articleBody">
151
+
152
+ <h1>Source code for dscript.fasta</h1><div class="highlight"><pre>
153
+ <div class="viewcode-block" id="parse"><a class="viewcode-back" href="../../api/index.html#dscript.fasta.parse">[docs]</a><span></span><span class="k">def</span> <span class="nf">parse</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s2">&quot;#&quot;</span><span class="p">):</span>
154
+ <span class="sd">&quot;&quot;&quot;</span>
155
+ <span class="sd"> Parse a file in ``.fasta`` format.</span>
156
+
157
+ <span class="sd"> :param f: Input file object</span>
158
+ <span class="sd"> :type f: _io.TextIOWrapper</span>
159
+ <span class="sd"> :param comment: Character used for comments</span>
160
+ <span class="sd"> :type comment: str</span>
161
+
162
+ <span class="sd"> :return: names, sequence</span>
163
+ <span class="sd"> :rtype: list[str], list[str]</span>
164
+ <span class="sd"> &quot;&quot;&quot;</span>
165
+ <span class="n">starter</span> <span class="o">=</span> <span class="s2">&quot;&gt;&quot;</span>
166
+ <span class="n">empty</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
167
+ <span class="k">if</span> <span class="s2">&quot;b&quot;</span> <span class="ow">in</span> <span class="n">f</span><span class="o">.</span><span class="n">mode</span><span class="p">:</span>
168
+ <span class="n">comment</span> <span class="o">=</span> <span class="sa">b</span><span class="s2">&quot;#&quot;</span>
169
+ <span class="n">starter</span> <span class="o">=</span> <span class="sa">b</span><span class="s2">&quot;&gt;&quot;</span>
170
+ <span class="n">empty</span> <span class="o">=</span> <span class="sa">b</span><span class="s2">&quot;&quot;</span>
171
+ <span class="n">names</span> <span class="o">=</span> <span class="p">[]</span>
172
+ <span class="n">sequences</span> <span class="o">=</span> <span class="p">[]</span>
173
+ <span class="n">name</span> <span class="o">=</span> <span class="kc">None</span>
174
+ <span class="n">sequence</span> <span class="o">=</span> <span class="p">[]</span>
175
+ <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">f</span><span class="p">:</span>
176
+ <span class="k">if</span> <span class="n">line</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="n">comment</span><span class="p">):</span>
177
+ <span class="k">continue</span>
178
+ <span class="n">line</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span>
179
+ <span class="k">if</span> <span class="n">line</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="n">starter</span><span class="p">):</span>
180
+ <span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
181
+ <span class="n">names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
182
+ <span class="n">sequences</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">empty</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sequence</span><span class="p">))</span>
183
+ <span class="n">name</span> <span class="o">=</span> <span class="n">line</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
184
+ <span class="n">sequence</span> <span class="o">=</span> <span class="p">[]</span>
185
+ <span class="k">else</span><span class="p">:</span>
186
+ <span class="n">sequence</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">line</span><span class="o">.</span><span class="n">upper</span><span class="p">())</span>
187
+ <span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
188
+ <span class="n">names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
189
+ <span class="n">sequences</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">empty</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sequence</span><span class="p">))</span>
190
+
191
+ <span class="k">return</span> <span class="n">names</span><span class="p">,</span> <span class="n">sequences</span></div>
192
+
193
+
194
+ <div class="viewcode-block" id="parse_directory"><a class="viewcode-back" href="../../api/index.html#dscript.fasta.parse_directory">[docs]</a><span class="k">def</span> <span class="nf">parse_directory</span><span class="p">(</span><span class="n">directory</span><span class="p">,</span> <span class="n">extension</span><span class="o">=</span><span class="s2">&quot;.seq&quot;</span><span class="p">):</span>
195
+ <span class="sd">&quot;&quot;&quot;</span>
196
+ <span class="sd"> Parse all files in a directory ending with ``extension``.</span>
197
+
198
+ <span class="sd"> :param directory: Input directory</span>
199
+ <span class="sd"> :type directory: str</span>
200
+ <span class="sd"> :param extension: Extension of all files to read in</span>
201
+ <span class="sd"> :type extension: str</span>
202
+
203
+ <span class="sd"> :return: names, sequence</span>
204
+ <span class="sd"> :rtype: list[str], list[str]</span>
205
+ <span class="sd"> &quot;&quot;&quot;</span>
206
+ <span class="n">names</span> <span class="o">=</span> <span class="p">[]</span>
207
+ <span class="n">sequences</span> <span class="o">=</span> <span class="p">[]</span>
208
+
209
+ <span class="k">for</span> <span class="n">seqPath</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">directory</span><span class="p">):</span>
210
+ <span class="k">if</span> <span class="n">seqPath</span><span class="o">.</span><span class="n">endswith</span><span class="p">(</span><span class="n">extension</span><span class="p">):</span>
211
+ <span class="n">n</span><span class="p">,</span> <span class="n">s</span> <span class="o">=</span> <span class="n">parse</span><span class="p">(</span><span class="nb">open</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{directory}</span><span class="s2">/</span><span class="si">{seqPath}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="s2">&quot;rb&quot;</span><span class="p">))</span>
212
+ <span class="n">names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">n</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">strip</span><span class="p">())</span>
213
+ <span class="n">sequences</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">s</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">strip</span><span class="p">())</span>
214
+ <span class="k">return</span> <span class="n">names</span><span class="p">,</span> <span class="n">sequences</span></div>
215
+
216
+
217
+ <div class="viewcode-block" id="write"><a class="viewcode-back" href="../../api/index.html#dscript.fasta.write">[docs]</a><span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="n">nam</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="n">f</span><span class="p">):</span>
218
+ <span class="sd">&quot;&quot;&quot;</span>
219
+ <span class="sd"> Write a file in ``.fasta`` format.</span>
220
+
221
+ <span class="sd"> :param nam: List of names</span>
222
+ <span class="sd"> :type nam: list[str]</span>
223
+ <span class="sd"> :param seq: List of sequences</span>
224
+ <span class="sd"> :type seq: list[str]</span>
225
+ <span class="sd"> :param f: Output file object</span>
226
+ <span class="sd"> :type f: _io.TextIOWrapper</span>
227
+ <span class="sd"> &quot;&quot;&quot;</span>
228
+ <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">nam</span><span class="p">,</span> <span class="n">seq</span><span class="p">):</span>
229
+ <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;&gt;</span><span class="si">{}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">n</span><span class="p">))</span>
230
+ <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">s</span><span class="p">))</span></div>
231
+ </pre></div>
232
+
233
+ </div>
234
+
235
+ </div>
236
+ <footer>
237
+
238
+
239
+ <hr/>
240
+
241
+ <div role="contentinfo">
242
+ <p>
243
+
244
+ &copy; Copyright 2020, Samuel Sledzieski, Rohit Singh
245
+
246
+ </p>
247
+ </div>
248
+
249
+
250
+
251
+ Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a
252
+
253
+ <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a>
254
+
255
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
256
+
257
+ </footer>
258
+
259
+ </div>
260
+ </div>
261
+
262
+ </section>
263
+
264
+ </div>
265
+
266
+
267
+ <script type="text/javascript">
268
+ jQuery(function () {
269
+ SphinxRtdTheme.Navigation.enable(true);
270
+ });
271
+ </script>
272
+
273
+
274
+
275
+
276
+
277
+
278
+ </body>
279
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/language_model.html ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+
10
+ <title>dscript.language_model &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+ <!--[if lt IE 9]>
27
+ <script src="../../_static/js/html5shiv.min.js"></script>
28
+ <![endif]-->
29
+
30
+
31
+ <script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
32
+ <script src="../../_static/jquery.js"></script>
33
+ <script src="../../_static/underscore.js"></script>
34
+ <script src="../../_static/doctools.js"></script>
35
+
36
+ <script type="text/javascript" src="../../_static/js/theme.js"></script>
37
+
38
+
39
+ <link rel="index" title="Index" href="../../genindex.html" />
40
+ <link rel="search" title="Search" href="../../search.html" />
41
+ </head>
42
+
43
+ <body class="wy-body-for-nav">
44
+
45
+
46
+ <div class="wy-grid-for-nav">
47
+
48
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
49
+ <div class="wy-side-scroll">
50
+ <div class="wy-side-nav-search" >
51
+
52
+
53
+
54
+ <a href="../../index.html" class="icon icon-home"> D-SCRIPT
55
+
56
+
57
+
58
+ </a>
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+ <div role="search">
67
+ <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
68
+ <input type="text" name="q" placeholder="Search docs" />
69
+ <input type="hidden" name="check_keywords" value="yes" />
70
+ <input type="hidden" name="area" value="default" />
71
+ </form>
72
+ </div>
73
+
74
+
75
+ </div>
76
+
77
+
78
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
79
+
80
+
81
+
82
+
83
+
84
+
85
+ <ul>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../installation.html">Installation</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../usage.html">Usage</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../data.html">Data</a></li>
89
+ <li class="toctree-l1"><a class="reference internal" href="../../api/index.html">API</a></li>
90
+ </ul>
91
+
92
+
93
+
94
+ </div>
95
+
96
+ </div>
97
+ </nav>
98
+
99
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
100
+
101
+
102
+ <nav class="wy-nav-top" aria-label="top navigation">
103
+
104
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
105
+ <a href="../../index.html">D-SCRIPT</a>
106
+
107
+ </nav>
108
+
109
+
110
+ <div class="wy-nav-content">
111
+
112
+ <div class="rst-content">
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ <div role="navigation" aria-label="breadcrumbs navigation">
133
+
134
+ <ul class="wy-breadcrumbs">
135
+
136
+ <li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
137
+
138
+ <li><a href="../index.html">Module code</a> &raquo;</li>
139
+
140
+ <li>dscript.language_model</li>
141
+
142
+
143
+ <li class="wy-breadcrumbs-aside">
144
+
145
+ </li>
146
+
147
+ </ul>
148
+
149
+
150
+ <hr/>
151
+ </div>
152
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
153
+ <div itemprop="articleBody">
154
+
155
+ <h1>Source code for dscript.language_model</h1><div class="highlight"><pre>
156
+ <span></span><span class="kn">import</span> <span class="nn">os</span><span class="o">,</span> <span class="nn">sys</span>
157
+ <span class="kn">import</span> <span class="nn">subprocess</span> <span class="k">as</span> <span class="nn">sp</span>
158
+ <span class="kn">import</span> <span class="nn">random</span>
159
+ <span class="kn">import</span> <span class="nn">torch</span>
160
+ <span class="kn">import</span> <span class="nn">h5py</span>
161
+ <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
162
+ <span class="kn">from</span> <span class="nn">.fasta</span> <span class="kn">import</span> <span class="n">parse</span><span class="p">,</span> <span class="n">parse_directory</span><span class="p">,</span> <span class="n">write</span>
163
+ <span class="kn">from</span> <span class="nn">.pretrained</span> <span class="kn">import</span> <span class="n">get_pretrained</span>
164
+ <span class="kn">from</span> <span class="nn">.alphabets</span> <span class="kn">import</span> <span class="n">Uniprot21</span>
165
+ <span class="kn">from</span> <span class="nn">.models.embedding</span> <span class="kn">import</span> <span class="n">SkipLSTM</span>
166
+ <span class="kn">from</span> <span class="nn">datetime</span> <span class="kn">import</span> <span class="n">datetime</span>
167
+
168
+
169
+ <div class="viewcode-block" id="lm_embed"><a class="viewcode-back" href="../../api/index.html#dscript.language_model.lm_embed">[docs]</a><span class="k">def</span> <span class="nf">lm_embed</span><span class="p">(</span><span class="n">sequence</span><span class="p">,</span> <span class="n">use_cuda</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
170
+ <span class="sd">&quot;&quot;&quot;</span>
171
+ <span class="sd"> Embed a single sequence using pre-trained language model from `Bepler &amp; Berger &lt;https://github.com/tbepler/protein-sequence-embedding-iclr2019&gt;`_.</span>
172
+
173
+ <span class="sd"> :param sequence: Input sequence to be embedded</span>
174
+ <span class="sd"> :type sequence: str</span>
175
+ <span class="sd"> :param use_cuda: Whether to generate embeddings using GPU device [default: False]</span>
176
+ <span class="sd"> :type use_cuda: bool</span>
177
+ <span class="sd"> :return: Embedded sequence</span>
178
+ <span class="sd"> :rtype: torch.Tensor</span>
179
+ <span class="sd"> &quot;&quot;&quot;</span>
180
+
181
+ <span class="n">model</span> <span class="o">=</span> <span class="n">get_pretrained</span><span class="p">(</span><span class="s2">&quot;lm_v1&quot;</span><span class="p">)</span>
182
+ <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">proj</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
183
+ <span class="n">model</span><span class="o">.</span><span class="n">proj</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">100</span><span class="p">))</span>
184
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
185
+ <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
186
+ <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
187
+
188
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
189
+ <span class="n">alphabet</span> <span class="o">=</span> <span class="n">Uniprot21</span><span class="p">()</span>
190
+ <span class="n">es</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">alphabet</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">sequence</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s1">&#39;utf-8&#39;</span><span class="p">)))</span>
191
+ <span class="n">x</span> <span class="o">=</span> <span class="n">es</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
192
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
193
+ <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
194
+ <span class="n">z</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
195
+ <span class="k">return</span> <span class="n">z</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span></div>
196
+
197
+
198
+ <div class="viewcode-block" id="embed_from_fasta"><a class="viewcode-back" href="../../api/index.html#dscript.language_model.embed_from_fasta">[docs]</a><span class="k">def</span> <span class="nf">embed_from_fasta</span><span class="p">(</span><span class="n">fastaPath</span><span class="p">,</span> <span class="n">outputPath</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
199
+ <span class="sd">&quot;&quot;&quot;</span>
200
+ <span class="sd"> Embed sequences using pre-trained language model from `Bepler &amp; Berger &lt;https://github.com/tbepler/protein-sequence-embedding-iclr2019&gt;`_.</span>
201
+
202
+ <span class="sd"> :param fastaPath: Input sequence file (``.fasta`` format)</span>
203
+ <span class="sd"> :type fastaPath: str</span>
204
+ <span class="sd"> :param outputPath: Output embedding file (``.h5`` format)</span>
205
+ <span class="sd"> :type outputPath: str</span>
206
+ <span class="sd"> :param device: Compute device to use for embeddings [default: 0]</span>
207
+ <span class="sd"> :type device: int</span>
208
+ <span class="sd"> :param verbose: Print embedding progress</span>
209
+ <span class="sd"> :type verbose: bool</span>
210
+ <span class="sd"> &quot;&quot;&quot;</span>
211
+ <span class="n">use_cuda</span> <span class="o">=</span> <span class="p">(</span><span class="n">device</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span>
212
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
213
+ <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
214
+ <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
215
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;# Using CUDA device </span><span class="si">{</span><span class="n">device</span><span class="si">}</span><span class="s2"> - </span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">get_device_name</span><span class="p">(</span><span class="n">device</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
216
+ <span class="k">else</span><span class="p">:</span>
217
+ <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
218
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Using CPU&quot;</span><span class="p">)</span>
219
+
220
+ <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
221
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Loading Model...&quot;</span><span class="p">)</span>
222
+ <span class="n">model</span> <span class="o">=</span> <span class="n">get_pretrained</span><span class="p">(</span><span class="s2">&quot;lm_v1&quot;</span><span class="p">)</span>
223
+ <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">proj</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
224
+ <span class="n">model</span><span class="o">.</span><span class="n">proj</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">100</span><span class="p">))</span>
225
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
226
+ <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
227
+
228
+ <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
229
+ <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
230
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Loading Sequences...&quot;</span><span class="p">)</span>
231
+ <span class="n">names</span><span class="p">,</span> <span class="n">seqs</span> <span class="o">=</span> <span class="n">parse</span><span class="p">(</span><span class="nb">open</span><span class="p">(</span><span class="n">fastaPath</span><span class="p">,</span> <span class="s2">&quot;rb&quot;</span><span class="p">))</span>
232
+ <span class="n">alphabet</span> <span class="o">=</span> <span class="n">Uniprot21</span><span class="p">()</span>
233
+ <span class="n">encoded_seqs</span> <span class="o">=</span> <span class="p">[]</span>
234
+ <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">seqs</span><span class="p">):</span>
235
+ <span class="n">es</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">alphabet</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">s</span><span class="p">))</span>
236
+ <span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
237
+ <span class="n">es</span> <span class="o">=</span> <span class="n">es</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
238
+ <span class="n">encoded_seqs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">es</span><span class="p">)</span>
239
+ <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
240
+ <span class="n">num_seqs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">encoded_seqs</span><span class="p">)</span>
241
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# </span><span class="si">{}</span><span class="s2"> Sequences Loaded&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">num_seqs</span><span class="p">))</span>
242
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Approximate Storage Required (varies by average sequence length): ~</span><span class="si">{}</span><span class="s2">GB&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">num_seqs</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span><span class="o">/</span><span class="mi">125</span><span class="p">)))</span>
243
+
244
+ <span class="n">h5fi</span> <span class="o">=</span> <span class="n">h5py</span><span class="o">.</span><span class="n">File</span><span class="p">(</span><span class="n">outputPath</span><span class="p">,</span> <span class="s2">&quot;w&quot;</span><span class="p">)</span>
245
+
246
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;# Storing to </span><span class="si">{}</span><span class="s2">...&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">outputPath</span><span class="p">))</span>
247
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
248
+ <span class="k">try</span><span class="p">:</span>
249
+ <span class="k">for</span> <span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">names</span><span class="p">,</span> <span class="n">encoded_seqs</span><span class="p">),</span><span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">)):</span>
250
+ <span class="n">name</span> <span class="o">=</span> <span class="n">n</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span>
251
+ <span class="k">if</span> <span class="ow">not</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">h5fi</span><span class="p">:</span>
252
+ <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
253
+ <span class="n">z</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
254
+ <span class="n">h5fi</span><span class="o">.</span><span class="n">create_dataset</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">z</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">compression</span><span class="o">=</span><span class="s2">&quot;lzf&quot;</span><span class="p">)</span>
255
+ <span class="k">except</span> <span class="ne">KeyboardInterrupt</span><span class="p">:</span>
256
+ <span class="n">h5fi</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
257
+ <span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
258
+ <span class="n">h5fi</span><span class="o">.</span><span class="n">close</span><span class="p">()</span></div>
259
+
260
+
261
+ <div class="viewcode-block" id="embed_from_directory"><a class="viewcode-back" href="../../api/index.html#dscript.language_model.embed_from_directory">[docs]</a><span class="k">def</span> <span class="nf">embed_from_directory</span><span class="p">(</span><span class="n">directory</span><span class="p">,</span> <span class="n">outputPath</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">extension</span><span class="o">=</span><span class="s2">&quot;.seq&quot;</span><span class="p">):</span>
262
+ <span class="sd">&quot;&quot;&quot;</span>
263
+ <span class="sd"> Embed all files in a directory in ``.fasta`` format using pre-trained language model from `Bepler &amp; Berger &lt;https://github.com/tbepler/protein-sequence-embedding-iclr2019&gt;`_.</span>
264
+
265
+ <span class="sd"> :param directory: Input directory (``.fasta`` format)</span>
266
+ <span class="sd"> :type directory: str</span>
267
+ <span class="sd"> :param outputPath: Output embedding file (``.h5`` format)</span>
268
+ <span class="sd"> :type outputPath: str</span>
269
+ <span class="sd"> :param device: Compute device to use for embeddings [default: 0]</span>
270
+ <span class="sd"> :type device: int</span>
271
+ <span class="sd"> :param verbose: Print embedding progress</span>
272
+ <span class="sd"> :type verbose: bool</span>
273
+ <span class="sd"> :param extension: Extension of all files to read in</span>
274
+ <span class="sd"> :type extension: str</span>
275
+ <span class="sd"> &quot;&quot;&quot;</span>
276
+ <span class="n">nam</span><span class="p">,</span> <span class="n">seq</span> <span class="o">=</span> <span class="n">parse_directory</span><span class="p">(</span><span class="n">directory</span><span class="p">,</span> <span class="n">extension</span><span class="o">=</span><span class="n">extension</span><span class="p">)</span>
277
+ <span class="n">fastaPath</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">directory</span><span class="si">}</span><span class="s2">/allSeqs.fa&quot;</span>
278
+ <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">fastaPath</span><span class="p">):</span>
279
+ <span class="n">fastaPath</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">fastaPath</span><span class="si">}</span><span class="s2">.</span><span class="si">{</span><span class="nb">int</span><span class="p">(</span><span class="n">datetime</span><span class="o">.</span><span class="n">utcnow</span><span class="p">()</span><span class="o">.</span><span class="n">timestamp</span><span class="p">())</span><span class="si">}</span><span class="s2">&quot;</span>
280
+ <span class="n">write</span><span class="p">(</span><span class="n">nam</span><span class="p">,</span> <span class="n">seq</span><span class="p">,</span> <span class="nb">open</span><span class="p">(</span><span class="n">fastaPath</span><span class="p">,</span> <span class="s2">&quot;w&quot;</span><span class="p">))</span>
281
+ <span class="n">embed_from_fasta</span><span class="p">(</span><span class="n">fastaPath</span><span class="p">,</span> <span class="n">outputPath</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">verbose</span><span class="p">)</span></div>
282
+ </pre></div>
283
+
284
+ </div>
285
+
286
+ </div>
287
+ <footer>
288
+
289
+ <hr/>
290
+
291
+ <div role="contentinfo">
292
+ <p>
293
+ &#169; Copyright 2020, Samuel Sledzieski, Rohit Singh.
294
+
295
+ </p>
296
+ </div>
297
+
298
+
299
+
300
+ Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
301
+
302
+ <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
303
+
304
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
305
+
306
+ </footer>
307
+ </div>
308
+ </div>
309
+
310
+ </section>
311
+
312
+ </div>
313
+
314
+
315
+ <script type="text/javascript">
316
+ jQuery(function () {
317
+ SphinxRtdTheme.Navigation.enable(true);
318
+ });
319
+ </script>
320
+
321
+
322
+
323
+
324
+
325
+
326
+ </body>
327
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/models/contact.html ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8">
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
9
+
10
+ <title>dscript.models.contact &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+ <!--[if lt IE 9]>
24
+ <script src="../../../_static/js/html5shiv.min.js"></script>
25
+ <![endif]-->
26
+
27
+
28
+ <script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
29
+ <script src="../../../_static/jquery.js"></script>
30
+ <script src="../../../_static/underscore.js"></script>
31
+ <script src="../../../_static/doctools.js"></script>
32
+ <script src="../../../_static/language_data.js"></script>
33
+ <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
34
+
35
+ <script type="text/javascript" src="../../../_static/js/theme.js"></script>
36
+
37
+
38
+ <link rel="index" title="Index" href="../../../genindex.html" />
39
+ <link rel="search" title="Search" href="../../../search.html" />
40
+ </head>
41
+
42
+ <body class="wy-body-for-nav">
43
+
44
+
45
+ <div class="wy-grid-for-nav">
46
+
47
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
48
+ <div class="wy-side-scroll">
49
+ <div class="wy-side-nav-search" >
50
+
51
+
52
+
53
+ <a href="../../../index.html" class="icon icon-home" alt="Documentation Home"> D-SCRIPT
54
+
55
+
56
+
57
+ </a>
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+ <div role="search">
66
+ <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
67
+ <input type="text" name="q" placeholder="Search docs" />
68
+ <input type="hidden" name="check_keywords" value="yes" />
69
+ <input type="hidden" name="area" value="default" />
70
+ </form>
71
+ </div>
72
+
73
+
74
+ </div>
75
+
76
+
77
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
78
+
79
+
80
+
81
+
82
+
83
+
84
+ <ul>
85
+ <li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Installation</a></li>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../../data.html">Data</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../../api/index.html">API</a></li>
89
+ </ul>
90
+
91
+
92
+
93
+ </div>
94
+
95
+ </div>
96
+ </nav>
97
+
98
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
99
+
100
+
101
+ <nav class="wy-nav-top" aria-label="top navigation">
102
+
103
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
104
+ <a href="../../../index.html">D-SCRIPT</a>
105
+
106
+ </nav>
107
+
108
+
109
+ <div class="wy-nav-content">
110
+
111
+ <div class="rst-content">
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+ <div role="navigation" aria-label="breadcrumbs navigation">
130
+
131
+ <ul class="wy-breadcrumbs">
132
+
133
+ <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
134
+
135
+ <li><a href="../../index.html">Module code</a> &raquo;</li>
136
+
137
+ <li>dscript.models.contact</li>
138
+
139
+
140
+ <li class="wy-breadcrumbs-aside">
141
+
142
+ </li>
143
+
144
+ </ul>
145
+
146
+
147
+ <hr/>
148
+ </div>
149
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
150
+ <div itemprop="articleBody">
151
+
152
+ <h1>Source code for dscript.models.contact</h1><div class="highlight"><pre>
153
+ <span></span><span class="sd">&quot;&quot;&quot;</span>
154
+ <span class="sd">Contact model classes.</span>
155
+ <span class="sd">&quot;&quot;&quot;</span>
156
+
157
+ <span class="kn">import</span> <span class="nn">torch</span>
158
+ <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
159
+ <span class="kn">import</span> <span class="nn">torch.functional</span> <span class="k">as</span> <span class="nn">F</span>
160
+
161
+
162
+ <div class="viewcode-block" id="FullyConnected"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.contact.FullyConnected">[docs]</a><span class="k">class</span> <span class="nc">FullyConnected</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
163
+ <span class="sd">&quot;&quot;&quot;</span>
164
+ <span class="sd"> Performs part 1 of Contact Prediction Module. Takes embeddings from Projection module and produces broadcast tensor.</span>
165
+
166
+ <span class="sd"> Input embeddings of dimension :math:`d` are combined into a :math:`2d` length MLP input :math:`z_{cat}`, where :math:`z_{cat} = [z_0 \\ominus z_1 | z_0 \\odot z_1]`</span>
167
+
168
+ <span class="sd"> :param embed_dim: Output dimension of `dscript.models.embedding &lt;#module-dscript.models.embedding&gt;`_ model :math:`d` [default: 100]</span>
169
+ <span class="sd"> :type embed_dim: int</span>
170
+ <span class="sd"> :param hidden_dim: Hidden dimension :math:`h` [default: 50]</span>
171
+ <span class="sd"> :type hidden_dim: int</span>
172
+ <span class="sd"> :param activation: Activation function for broadcast tensor [default: torch.nn.ReLU()]</span>
173
+ <span class="sd"> :type activation: torch.nn.Module</span>
174
+ <span class="sd"> &quot;&quot;&quot;</span>
175
+
176
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()):</span>
177
+ <span class="nb">super</span><span class="p">(</span><span class="n">FullyConnected</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
178
+
179
+ <span class="bp">self</span><span class="o">.</span><span class="n">D</span> <span class="o">=</span> <span class="n">embed_dim</span>
180
+ <span class="bp">self</span><span class="o">.</span><span class="n">H</span> <span class="o">=</span> <span class="n">hidden_dim</span>
181
+ <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">D</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">H</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
182
+ <span class="bp">self</span><span class="o">.</span><span class="n">batchnorm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">H</span><span class="p">)</span>
183
+ <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span>
184
+
185
+ <div class="viewcode-block" id="FullyConnected.forward"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.contact.FullyConnected.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">):</span>
186
+ <span class="sd">&quot;&quot;&quot;</span>
187
+ <span class="sd"> :param z0: Projection module embedding :math:`(b \\times N \\times d)`</span>
188
+ <span class="sd"> :type z0: torch.Tensor</span>
189
+ <span class="sd"> :param z1: Projection module embedding :math:`(b \\times M \\times d)`</span>
190
+ <span class="sd"> :type z1: torch.Tensor</span>
191
+ <span class="sd"> :return: Predicted broadcast tensor :math:`(b \\times N \\times M \\times h)`</span>
192
+ <span class="sd"> :rtype: torch.Tensor</span>
193
+ <span class="sd"> &quot;&quot;&quot;</span>
194
+
195
+ <span class="c1"># z0 is (b,N,d), z1 is (b,M,d)</span>
196
+ <span class="n">z0</span> <span class="o">=</span> <span class="n">z0</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
197
+ <span class="n">z1</span> <span class="o">=</span> <span class="n">z1</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
198
+ <span class="c1"># z0 is (b,d,N), z1 is (b,d,M)</span>
199
+
200
+ <span class="n">z_dif</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">z0</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="o">-</span> <span class="n">z1</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">2</span><span class="p">))</span>
201
+ <span class="n">z_mul</span> <span class="o">=</span> <span class="n">z0</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="n">z1</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
202
+ <span class="n">z_cat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">z_dif</span><span class="p">,</span> <span class="n">z_mul</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
203
+
204
+ <span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">z_cat</span><span class="p">)</span>
205
+ <span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
206
+ <span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batchnorm</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
207
+
208
+ <span class="k">return</span> <span class="n">b</span></div></div>
209
+
210
+
211
+ <div class="viewcode-block" id="ContactCNN"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.contact.ContactCNN">[docs]</a><span class="k">class</span> <span class="nc">ContactCNN</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
212
+ <span class="sd">&quot;&quot;&quot;</span>
213
+ <span class="sd"> Residue Contact Prediction Module. Takes embeddings from Projection module and produces contact map, output of Contact module.</span>
214
+
215
+ <span class="sd"> :param embed_dim: Output dimension of `dscript.models.embedding &lt;#module-dscript.models.embedding&gt;`_ model :math:`d` [default: 100]</span>
216
+ <span class="sd"> :type embed_dim: int</span>
217
+ <span class="sd"> :param hidden_dim: Hidden dimension :math:`h` [default: 50]</span>
218
+ <span class="sd"> :type hidden_dim: int</span>
219
+ <span class="sd"> :param width: Width of convolutional filter :math:`2w+1` [default: 7]</span>
220
+ <span class="sd"> :type width: int</span>
221
+ <span class="sd"> :param activation: Activation function for final contact map [default: torch.nn.Sigmoid()]</span>
222
+ <span class="sd"> :type activation: torch.nn.Module</span>
223
+ <span class="sd"> &quot;&quot;&quot;</span>
224
+
225
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()):</span>
226
+ <span class="nb">super</span><span class="p">(</span><span class="n">ContactCNN</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
227
+
228
+ <span class="bp">self</span><span class="o">.</span><span class="n">hidden</span> <span class="o">=</span> <span class="n">FullyConnected</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span>
229
+ <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">hidden_dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">width</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
230
+ <span class="bp">self</span><span class="o">.</span><span class="n">batchnorm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
231
+ <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span>
232
+ <span class="bp">self</span><span class="o">.</span><span class="n">clip</span><span class="p">()</span>
233
+
234
+ <div class="viewcode-block" id="ContactCNN.clip"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.contact.ContactCNN.clip">[docs]</a> <span class="k">def</span> <span class="nf">clip</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
235
+ <span class="sd">&quot;&quot;&quot;</span>
236
+ <span class="sd"> Force the convolutional layer to be transpose invariant.</span>
237
+
238
+ <span class="sd"> :meta private:</span>
239
+ <span class="sd"> &quot;&quot;&quot;</span>
240
+
241
+ <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="o">.</span><span class="n">weight</span>
242
+ <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">[:]</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="n">w</span> <span class="o">+</span> <span class="n">w</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span></div>
243
+
244
+ <div class="viewcode-block" id="ContactCNN.forward"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.contact.ContactCNN.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">):</span>
245
+ <span class="sd">&quot;&quot;&quot;</span>
246
+ <span class="sd"> :param z0: Projection module embedding :math:`(b \\times N \\times d)`</span>
247
+ <span class="sd"> :type z0: torch.Tensor</span>
248
+ <span class="sd"> :param z1: Projection module embedding :math:`(b \\times M \\times d)`</span>
249
+ <span class="sd"> :type z1: torch.Tensor</span>
250
+ <span class="sd"> :return: Predicted contact map :math:`(b \\times N \\times M)`</span>
251
+ <span class="sd"> :rtype: torch.Tensor</span>
252
+ <span class="sd"> &quot;&quot;&quot;</span>
253
+ <span class="n">B</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">)</span>
254
+ <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">C</span><span class="p">)</span></div>
255
+
256
+ <div class="viewcode-block" id="ContactCNN.broadcast"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.contact.ContactCNN.broadcast">[docs]</a> <span class="k">def</span> <span class="nf">broadcast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">):</span>
257
+ <span class="sd">&quot;&quot;&quot;</span>
258
+ <span class="sd"> Calls `dscript.models.contact.FullyConnected &lt;#module-dscript.models.contact.FullyConnected&gt;`_.</span>
259
+
260
+ <span class="sd"> :param z0: Projection module embedding :math:`(b \\times N \\times d)`</span>
261
+ <span class="sd"> :type z0: torch.Tensor</span>
262
+ <span class="sd"> :param z1: Projection module embedding :math:`(b \\times M \\times d)`</span>
263
+ <span class="sd"> :type z1: torch.Tensor</span>
264
+ <span class="sd"> :return: Predicted contact broadcast tensor :math:`(b \\times N \\times M \\times h)`</span>
265
+ <span class="sd"> :rtype: torch.Tensor</span>
266
+ <span class="sd"> &quot;&quot;&quot;</span>
267
+ <span class="n">B</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden</span><span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">)</span>
268
+ <span class="k">return</span> <span class="n">B</span></div>
269
+
270
+ <div class="viewcode-block" id="ContactCNN.predict"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.contact.ContactCNN.predict">[docs]</a> <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">B</span><span class="p">):</span>
271
+ <span class="sd">&quot;&quot;&quot;</span>
272
+ <span class="sd"> Predict contact map from broadcast tensor.</span>
273
+
274
+ <span class="sd"> :param B: Predicted contact broadcast :math:`(b \\times N \\times M \\times h)`</span>
275
+ <span class="sd"> :type B: torch.Tensor</span>
276
+ <span class="sd"> :return: Predicted contact map :math:`(b \\times N \\times M)`</span>
277
+ <span class="sd"> :rtype: torch.Tensor</span>
278
+ <span class="sd"> &quot;&quot;&quot;</span>
279
+ <span class="n">C</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
280
+ <span class="n">C</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batchnorm</span><span class="p">(</span><span class="n">C</span><span class="p">)</span>
281
+ <span class="n">C</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">C</span><span class="p">)</span>
282
+ <span class="k">return</span> <span class="n">C</span></div></div>
283
+ </pre></div>
284
+
285
+ </div>
286
+
287
+ </div>
288
+ <footer>
289
+
290
+
291
+ <hr/>
292
+
293
+ <div role="contentinfo">
294
+ <p>
295
+
296
+ &copy; Copyright 2020, Samuel Sledzieski, Rohit Singh
297
+
298
+ </p>
299
+ </div>
300
+
301
+
302
+
303
+ Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a
304
+
305
+ <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a>
306
+
307
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
308
+
309
+ </footer>
310
+
311
+ </div>
312
+ </div>
313
+
314
+ </section>
315
+
316
+ </div>
317
+
318
+
319
+ <script type="text/javascript">
320
+ jQuery(function () {
321
+ SphinxRtdTheme.Navigation.enable(true);
322
+ });
323
+ </script>
324
+
325
+
326
+
327
+
328
+
329
+
330
+ </body>
331
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/models/embedding.html ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+
10
+ <title>dscript.models.embedding &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+ <!--[if lt IE 9]>
27
+ <script src="../../../_static/js/html5shiv.min.js"></script>
28
+ <![endif]-->
29
+
30
+
31
+ <script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
32
+ <script src="../../../_static/jquery.js"></script>
33
+ <script src="../../../_static/underscore.js"></script>
34
+ <script src="../../../_static/doctools.js"></script>
35
+
36
+ <script type="text/javascript" src="../../../_static/js/theme.js"></script>
37
+
38
+
39
+ <link rel="index" title="Index" href="../../../genindex.html" />
40
+ <link rel="search" title="Search" href="../../../search.html" />
41
+ </head>
42
+
43
+ <body class="wy-body-for-nav">
44
+
45
+
46
+ <div class="wy-grid-for-nav">
47
+
48
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
49
+ <div class="wy-side-scroll">
50
+ <div class="wy-side-nav-search" >
51
+
52
+
53
+
54
+ <a href="../../../index.html" class="icon icon-home"> D-SCRIPT
55
+
56
+
57
+
58
+ </a>
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+ <div role="search">
67
+ <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
68
+ <input type="text" name="q" placeholder="Search docs" />
69
+ <input type="hidden" name="check_keywords" value="yes" />
70
+ <input type="hidden" name="area" value="default" />
71
+ </form>
72
+ </div>
73
+
74
+
75
+ </div>
76
+
77
+
78
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
79
+
80
+
81
+
82
+
83
+
84
+
85
+ <ul>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Installation</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../../data.html">Data</a></li>
89
+ <li class="toctree-l1"><a class="reference internal" href="../../../api/index.html">API</a></li>
90
+ </ul>
91
+
92
+
93
+
94
+ </div>
95
+
96
+ </div>
97
+ </nav>
98
+
99
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
100
+
101
+
102
+ <nav class="wy-nav-top" aria-label="top navigation">
103
+
104
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
105
+ <a href="../../../index.html">D-SCRIPT</a>
106
+
107
+ </nav>
108
+
109
+
110
+ <div class="wy-nav-content">
111
+
112
+ <div class="rst-content">
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ <div role="navigation" aria-label="breadcrumbs navigation">
133
+
134
+ <ul class="wy-breadcrumbs">
135
+
136
+ <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
137
+
138
+ <li><a href="../../index.html">Module code</a> &raquo;</li>
139
+
140
+ <li>dscript.models.embedding</li>
141
+
142
+
143
+ <li class="wy-breadcrumbs-aside">
144
+
145
+ </li>
146
+
147
+ </ul>
148
+
149
+
150
+ <hr/>
151
+ </div>
152
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
153
+ <div itemprop="articleBody">
154
+
155
+ <h1>Source code for dscript.models.embedding</h1><div class="highlight"><pre>
156
+ <span></span><span class="sd">&quot;&quot;&quot;</span>
157
+ <span class="sd">Embedding model classes.</span>
158
+ <span class="sd">&quot;&quot;&quot;</span>
159
+
160
+ <span class="kn">import</span> <span class="nn">torch</span>
161
+ <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
162
+ <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
163
+ <span class="kn">from</span> <span class="nn">torch.nn.utils.rnn</span> <span class="kn">import</span> <span class="n">PackedSequence</span>
164
+
165
+
166
+ <div class="viewcode-block" id="IdentityEmbed"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.embedding.IdentityEmbed">[docs]</a><span class="k">class</span> <span class="nc">IdentityEmbed</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
167
+ <span class="sd">&quot;&quot;&quot;</span>
168
+ <span class="sd"> Does not reduce the dimension of the language model embeddings, just passes them through to the contact model.</span>
169
+ <span class="sd"> &quot;&quot;&quot;</span>
170
+ <div class="viewcode-block" id="IdentityEmbed.forward"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.embedding.IdentityEmbed.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
171
+ <span class="sd">&quot;&quot;&quot;</span>
172
+ <span class="sd"> :param x: Input language model embedding :math:`(b \\times N \\times d_0)`</span>
173
+ <span class="sd"> :type x: torch.Tensor</span>
174
+ <span class="sd"> :return: Same embedding</span>
175
+ <span class="sd"> :rtype: torch.Tensor</span>
176
+ <span class="sd"> &quot;&quot;&quot;</span>
177
+ <span class="k">return</span> <span class="n">x</span></div></div>
178
+
179
+
180
+ <div class="viewcode-block" id="FullyConnectedEmbed"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.embedding.FullyConnectedEmbed">[docs]</a><span class="k">class</span> <span class="nc">FullyConnectedEmbed</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
181
+ <span class="sd">&quot;&quot;&quot;</span>
182
+ <span class="sd"> Protein Projection Module. Takes embedding from language model and outputs low-dimensional interaction aware projection.</span>
183
+
184
+ <span class="sd"> :param nin: Size of language model output</span>
185
+ <span class="sd"> :type nin: int</span>
186
+ <span class="sd"> :param nout: Dimension of projection</span>
187
+ <span class="sd"> :type nout: int</span>
188
+ <span class="sd"> :param dropout: Proportion of weights to drop out [default: 0.5]</span>
189
+ <span class="sd"> :type dropout: float</span>
190
+ <span class="sd"> :param activation: Activation for linear projection model</span>
191
+ <span class="sd"> :type activation: torch.nn.Module</span>
192
+ <span class="sd"> &quot;&quot;&quot;</span>
193
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nin</span><span class="p">,</span> <span class="n">nout</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()):</span>
194
+ <span class="nb">super</span><span class="p">(</span><span class="n">FullyConnectedEmbed</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
195
+ <span class="bp">self</span><span class="o">.</span><span class="n">nin</span> <span class="o">=</span> <span class="n">nin</span>
196
+ <span class="bp">self</span><span class="o">.</span><span class="n">nout</span> <span class="o">=</span> <span class="n">nout</span>
197
+ <span class="bp">self</span><span class="o">.</span><span class="n">dropout_p</span> <span class="o">=</span> <span class="n">dropout</span>
198
+
199
+ <span class="bp">self</span><span class="o">.</span><span class="n">transform</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">nin</span><span class="p">,</span> <span class="n">nout</span><span class="p">)</span>
200
+ <span class="bp">self</span><span class="o">.</span><span class="n">drop</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout_p</span><span class="p">)</span>
201
+ <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span>
202
+
203
+ <div class="viewcode-block" id="FullyConnectedEmbed.forward"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.embedding.FullyConnectedEmbed.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
204
+ <span class="sd">&quot;&quot;&quot;</span>
205
+ <span class="sd"> :param x: Input language model embedding :math:`(b \\times N \\times d_0)`</span>
206
+ <span class="sd"> :type x: torch.Tensor</span>
207
+ <span class="sd"> :return: Low dimensional projection of embedding</span>
208
+ <span class="sd"> :rtype: torch.Tensor</span>
209
+ <span class="sd"> &quot;&quot;&quot;</span>
210
+ <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
211
+ <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
212
+ <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
213
+ <span class="k">return</span> <span class="n">t</span></div></div>
214
+
215
+
216
+ <div class="viewcode-block" id="SkipLSTM"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.embedding.SkipLSTM">[docs]</a><span class="k">class</span> <span class="nc">SkipLSTM</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
217
+ <span class="sd">&quot;&quot;&quot;</span>
218
+ <span class="sd"> Language model from `Bepler &amp; Berger &lt;https://github.com/tbepler/protein-sequence-embedding-iclr2019&gt;`_.</span>
219
+
220
+ <span class="sd"> Loaded with pre-trained weights in embedding function.</span>
221
+
222
+ <span class="sd"> :param nin: Input dimension of amino acid one-hot [default: 21]</span>
223
+ <span class="sd"> :type nin: int</span>
224
+ <span class="sd"> :param nout: Output dimension of final layer [default: 100]</span>
225
+ <span class="sd"> :type nout: int</span>
226
+ <span class="sd"> :param hidden_dim: Size of hidden dimension [default: 1024]</span>
227
+ <span class="sd"> :type hidden_dim: int</span>
228
+ <span class="sd"> :param num_layers: Number of stacked LSTM models [default: 3]</span>
229
+ <span class="sd"> :type num_layers: int</span>
230
+ <span class="sd"> :param dropout: Proportion of weights to drop out [default: 0]</span>
231
+ <span class="sd"> :type dropout: float</span>
232
+ <span class="sd"> :param bidirectional: Whether to use biLSTM vs. LSTM</span>
233
+ <span class="sd"> :type bidirectional: bool</span>
234
+ <span class="sd"> &quot;&quot;&quot;</span>
235
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nin</span><span class="o">=</span><span class="mi">21</span><span class="p">,</span> <span class="n">nout</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">bidirectional</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
236
+ <span class="nb">super</span><span class="p">(</span><span class="n">SkipLSTM</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
237
+
238
+ <span class="bp">self</span><span class="o">.</span><span class="n">nin</span> <span class="o">=</span> <span class="n">nin</span>
239
+ <span class="bp">self</span><span class="o">.</span><span class="n">nout</span> <span class="o">=</span> <span class="n">nout</span>
240
+
241
+ <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span>
242
+
243
+ <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">()</span>
244
+ <span class="n">dim</span> <span class="o">=</span> <span class="n">nin</span>
245
+ <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span>
246
+ <span class="n">f</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">bidirectional</span><span class="o">=</span><span class="n">bidirectional</span><span class="p">)</span>
247
+ <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
248
+ <span class="k">if</span> <span class="n">bidirectional</span><span class="p">:</span>
249
+ <span class="n">dim</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">hidden_dim</span>
250
+ <span class="k">else</span><span class="p">:</span>
251
+ <span class="n">dim</span> <span class="o">=</span> <span class="n">hidden_dim</span>
252
+
253
+ <span class="n">n</span> <span class="o">=</span> <span class="n">hidden_dim</span> <span class="o">*</span> <span class="n">num_layers</span> <span class="o">+</span> <span class="n">nin</span>
254
+ <span class="k">if</span> <span class="n">bidirectional</span><span class="p">:</span>
255
+ <span class="n">n</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">hidden_dim</span> <span class="o">*</span> <span class="n">num_layers</span> <span class="o">+</span> <span class="n">nin</span>
256
+
257
+ <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">nout</span><span class="p">)</span>
258
+
259
+ <div class="viewcode-block" id="SkipLSTM.to_one_hot"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.embedding.SkipLSTM.to_one_hot">[docs]</a> <span class="k">def</span> <span class="nf">to_one_hot</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
260
+ <span class="sd">&quot;&quot;&quot;</span>
261
+ <span class="sd"> Transform numeric encoded amino acid vector to one-hot encoded vector</span>
262
+
263
+ <span class="sd"> :param x: Input numeric amino acid encoding :math:`(N)`</span>
264
+ <span class="sd"> :type x: torch.Tensor</span>
265
+ <span class="sd"> :return: One-hot encoding vector :math:`(N \\times n_{in})`</span>
266
+ <span class="sd"> :rtype: torch.Tensor</span>
267
+ <span class="sd"> &quot;&quot;&quot;</span>
268
+ <span class="n">packed</span> <span class="o">=</span> <span class="nb">type</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="ow">is</span> <span class="n">PackedSequence</span>
269
+ <span class="k">if</span> <span class="n">packed</span><span class="p">:</span>
270
+ <span class="n">one_hot</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">nin</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">zero_</span><span class="p">()</span>
271
+ <span class="n">one_hot</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
272
+ <span class="n">one_hot</span> <span class="o">=</span> <span class="n">PackedSequence</span><span class="p">(</span><span class="n">one_hot</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">batch_sizes</span><span class="p">)</span>
273
+ <span class="k">else</span><span class="p">:</span>
274
+ <span class="n">one_hot</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">nin</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">zero_</span><span class="p">()</span>
275
+ <span class="n">one_hot</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
276
+ <span class="k">return</span> <span class="n">one_hot</span></div>
277
+
278
+ <div class="viewcode-block" id="SkipLSTM.transform"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.embedding.SkipLSTM.transform">[docs]</a> <span class="k">def</span> <span class="nf">transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
279
+ <span class="sd">&quot;&quot;&quot;</span>
280
+ <span class="sd"> :param x: Input numeric amino acid encoding :math:`(N)`</span>
281
+ <span class="sd"> :type x: torch.Tensor</span>
282
+ <span class="sd"> :return: Concatenation of all hidden layers :math:`(N \\times (n_{in} + 2 \\times \\text{num_layers} \\times \\text{hidden_dim}))`</span>
283
+ <span class="sd"> :rtype: torch.Tensor</span>
284
+ <span class="sd"> &quot;&quot;&quot;</span>
285
+ <span class="n">one_hot</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_one_hot</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
286
+ <span class="n">hs</span> <span class="o">=</span> <span class="p">[</span><span class="n">one_hot</span><span class="p">]</span> <span class="c1"># []</span>
287
+ <span class="n">h_</span> <span class="o">=</span> <span class="n">one_hot</span>
288
+ <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
289
+ <span class="n">h</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">h_</span><span class="p">)</span>
290
+ <span class="c1"># h = self.dropout(h)</span>
291
+ <span class="n">hs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
292
+ <span class="n">h_</span> <span class="o">=</span> <span class="n">h</span>
293
+ <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="ow">is</span> <span class="n">PackedSequence</span><span class="p">:</span>
294
+ <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">z</span><span class="o">.</span><span class="n">data</span> <span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="n">hs</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
295
+ <span class="n">h</span> <span class="o">=</span> <span class="n">PackedSequence</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">batch_sizes</span><span class="p">)</span>
296
+ <span class="k">else</span><span class="p">:</span>
297
+ <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">z</span> <span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="n">hs</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span>
298
+ <span class="k">return</span> <span class="n">h</span></div>
299
+
300
+ <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
301
+ <span class="sd">&quot;&quot;&quot;</span>
302
+ <span class="sd"> :meta private:</span>
303
+ <span class="sd"> &quot;&quot;&quot;</span>
304
+ <span class="n">one_hot</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_one_hot</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
305
+ <span class="n">hs</span> <span class="o">=</span> <span class="p">[</span><span class="n">one_hot</span><span class="p">]</span>
306
+ <span class="n">h_</span> <span class="o">=</span> <span class="n">one_hot</span>
307
+
308
+ <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
309
+ <span class="n">h</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">h_</span><span class="p">)</span>
310
+ <span class="c1"># h = self.dropout(h)</span>
311
+ <span class="n">hs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
312
+ <span class="n">h_</span> <span class="o">=</span> <span class="n">h</span>
313
+
314
+ <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="ow">is</span> <span class="n">PackedSequence</span><span class="p">:</span>
315
+ <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">z</span><span class="o">.</span><span class="n">data</span> <span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="n">hs</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
316
+ <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
317
+ <span class="n">z</span> <span class="o">=</span> <span class="n">PackedSequence</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">batch_sizes</span><span class="p">)</span>
318
+ <span class="k">else</span><span class="p">:</span>
319
+ <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">z</span> <span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="n">hs</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span>
320
+ <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">h</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">h</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">)))</span>
321
+ <span class="n">z</span> <span class="o">=</span> <span class="n">z</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
322
+
323
+ <span class="k">return</span> <span class="n">z</span></div>
324
+ </pre></div>
325
+
326
+ </div>
327
+
328
+ </div>
329
+ <footer>
330
+
331
+ <hr/>
332
+
333
+ <div role="contentinfo">
334
+ <p>
335
+ &#169; Copyright 2020, Samuel Sledzieski, Rohit Singh.
336
+
337
+ </p>
338
+ </div>
339
+
340
+
341
+
342
+ Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
343
+
344
+ <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
345
+
346
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
347
+
348
+ </footer>
349
+ </div>
350
+ </div>
351
+
352
+ </section>
353
+
354
+ </div>
355
+
356
+
357
+ <script type="text/javascript">
358
+ jQuery(function () {
359
+ SphinxRtdTheme.Navigation.enable(true);
360
+ });
361
+ </script>
362
+
363
+
364
+
365
+
366
+
367
+
368
+ </body>
369
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/models/interaction.html ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+
10
+ <title>dscript.models.interaction &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+ <!--[if lt IE 9]>
27
+ <script src="../../../_static/js/html5shiv.min.js"></script>
28
+ <![endif]-->
29
+
30
+
31
+ <script type="text/javascript" id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
32
+ <script src="../../../_static/jquery.js"></script>
33
+ <script src="../../../_static/underscore.js"></script>
34
+ <script src="../../../_static/doctools.js"></script>
35
+
36
+ <script type="text/javascript" src="../../../_static/js/theme.js"></script>
37
+
38
+
39
+ <link rel="index" title="Index" href="../../../genindex.html" />
40
+ <link rel="search" title="Search" href="../../../search.html" />
41
+ </head>
42
+
43
+ <body class="wy-body-for-nav">
44
+
45
+
46
+ <div class="wy-grid-for-nav">
47
+
48
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
49
+ <div class="wy-side-scroll">
50
+ <div class="wy-side-nav-search" >
51
+
52
+
53
+
54
+ <a href="../../../index.html" class="icon icon-home"> D-SCRIPT
55
+
56
+
57
+
58
+ </a>
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+ <div role="search">
67
+ <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
68
+ <input type="text" name="q" placeholder="Search docs" />
69
+ <input type="hidden" name="check_keywords" value="yes" />
70
+ <input type="hidden" name="area" value="default" />
71
+ </form>
72
+ </div>
73
+
74
+
75
+ </div>
76
+
77
+
78
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
79
+
80
+
81
+
82
+
83
+
84
+
85
+ <ul>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../../installation.html">Installation</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../../usage.html">Usage</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../../data.html">Data</a></li>
89
+ <li class="toctree-l1"><a class="reference internal" href="../../../api/index.html">API</a></li>
90
+ </ul>
91
+
92
+
93
+
94
+ </div>
95
+
96
+ </div>
97
+ </nav>
98
+
99
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
100
+
101
+
102
+ <nav class="wy-nav-top" aria-label="top navigation">
103
+
104
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
105
+ <a href="../../../index.html">D-SCRIPT</a>
106
+
107
+ </nav>
108
+
109
+
110
+ <div class="wy-nav-content">
111
+
112
+ <div class="rst-content">
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ <div role="navigation" aria-label="breadcrumbs navigation">
133
+
134
+ <ul class="wy-breadcrumbs">
135
+
136
+ <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
137
+
138
+ <li><a href="../../index.html">Module code</a> &raquo;</li>
139
+
140
+ <li>dscript.models.interaction</li>
141
+
142
+
143
+ <li class="wy-breadcrumbs-aside">
144
+
145
+ </li>
146
+
147
+ </ul>
148
+
149
+
150
+ <hr/>
151
+ </div>
152
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
153
+ <div itemprop="articleBody">
154
+
155
+ <h1>Source code for dscript.models.interaction</h1><div class="highlight"><pre>
156
+ <span></span><span class="sd">&quot;&quot;&quot;</span>
157
+ <span class="sd">Interaction model classes.</span>
158
+ <span class="sd">&quot;&quot;&quot;</span>
159
+
160
+ <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
161
+
162
+ <span class="kn">import</span> <span class="nn">torch</span>
163
+ <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
164
+ <span class="kn">import</span> <span class="nn">torch.functional</span> <span class="k">as</span> <span class="nn">F</span>
165
+
166
+
167
+ <div class="viewcode-block" id="LogisticActivation"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.interaction.LogisticActivation">[docs]</a><span class="k">class</span> <span class="nc">LogisticActivation</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
168
+ <span class="sd">&quot;&quot;&quot;</span>
169
+ <span class="sd"> Implementation of Generalized Sigmoid</span>
170
+ <span class="sd"> Applies the element-wise function:</span>
171
+
172
+ <span class="sd"> :math:`\\sigma(x) = \\frac{1}{1 + \\exp(-k(x-x_0))}`</span>
173
+
174
+ <span class="sd"> :param x0: The value of the sigmoid midpoint</span>
175
+ <span class="sd"> :type x0: float</span>
176
+ <span class="sd"> :param k: The slope of the sigmoid - trainable - :math:`k \\geq 0`</span>
177
+ <span class="sd"> :type k: float</span>
178
+ <span class="sd"> :param train: Whether :math:`k` is a trainable parameter</span>
179
+ <span class="sd"> :type train: bool</span>
180
+ <span class="sd"> &quot;&quot;&quot;</span>
181
+
182
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x0</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
183
+ <span class="nb">super</span><span class="p">(</span><span class="n">LogisticActivation</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
184
+ <span class="bp">self</span><span class="o">.</span><span class="n">x0</span> <span class="o">=</span> <span class="n">x0</span>
185
+ <span class="bp">self</span><span class="o">.</span><span class="n">k</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">([</span><span class="nb">float</span><span class="p">(</span><span class="n">k</span><span class="p">)]))</span>
186
+ <span class="bp">self</span><span class="o">.</span><span class="n">k</span><span class="o">.</span><span class="n">requiresGrad</span> <span class="o">=</span> <span class="n">train</span>
187
+
188
+ <div class="viewcode-block" id="LogisticActivation.forward"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.interaction.LogisticActivation.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
189
+ <span class="sd">&quot;&quot;&quot;</span>
190
+ <span class="sd"> Applies the function to the input elementwise</span>
191
+
192
+ <span class="sd"> :param x: :math:`(N \\times *)` where :math:`*` means, any number of additional dimensions</span>
193
+ <span class="sd"> :type x: torch.Tensor</span>
194
+ <span class="sd"> :return: :math:`(N \\times *)`, same shape as the input</span>
195
+ <span class="sd"> :rtype: torch.Tensor</span>
196
+ <span class="sd"> &quot;&quot;&quot;</span>
197
+ <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">k</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">x0</span><span class="p">))),</span> <span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
198
+ <span class="k">return</span> <span class="n">out</span></div>
199
+
200
+ <span class="k">def</span> <span class="nf">clip</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
201
+ <span class="sd">&quot;&quot;&quot;</span>
202
+ <span class="sd"> Restricts sigmoid slope :math:`k` to be greater than or equal to 0, if :math:`k` is trained.</span>
203
+
204
+ <span class="sd"> :meta private:</span>
205
+ <span class="sd"> &quot;&quot;&quot;</span>
206
+ <span class="bp">self</span><span class="o">.</span><span class="n">k</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></div>
207
+
208
+
209
+ <div class="viewcode-block" id="ModelInteraction"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.interaction.ModelInteraction">[docs]</a><span class="k">class</span> <span class="nc">ModelInteraction</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
210
+ <span class="sd">&quot;&quot;&quot;</span>
211
+ <span class="sd"> Main D-SCRIPT model. Contains an embedding and contact model and offers access to those models. Computes pooling operations on contact map to generate interaction probability.</span>
212
+
213
+ <span class="sd"> :param embedding: Embedding model</span>
214
+ <span class="sd"> :type embedding: dscript.models.embedding.FullyConnectedEmbed</span>
215
+ <span class="sd"> :param contact: Contact model</span>
216
+ <span class="sd"> :type contact: dscript.models.contact.ContactCNN</span>
217
+ <span class="sd"> :param use_cuda: Whether the model should be run on GPU</span>
218
+ <span class="sd"> :type use_cuda: bool</span>
219
+ <span class="sd"> :param pool_size: width of max-pool [default 9]</span>
220
+ <span class="sd"> :type pool_size: bool</span>
221
+ <span class="sd"> :param theta_init: initialization value of :math:`\\theta` for weight matrix [default: 1]</span>
222
+ <span class="sd"> :type theta_init: float</span>
223
+ <span class="sd"> :param lambda_init: initialization value of :math:`\\lambda` for weight matrix [default: 0]</span>
224
+ <span class="sd"> :type lambda_init: float</span>
225
+ <span class="sd"> :param gamma_init: initialization value of :math:`\\gamma` for global pooling [default: 0]</span>
226
+ <span class="sd"> :type gamma_init: float</span>
227
+ <span class="sd"> :param use_W: whether to use the weighting matrix [default: True]</span>
228
+ <span class="sd"> :type use_W: bool</span>
229
+ <span class="sd"> &quot;&quot;&quot;</span>
230
+
231
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
232
+ <span class="bp">self</span><span class="p">,</span>
233
+ <span class="n">embedding</span><span class="p">,</span>
234
+ <span class="n">contact</span><span class="p">,</span>
235
+ <span class="n">pool_size</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span>
236
+ <span class="n">theta_init</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
237
+ <span class="n">lambda_init</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
238
+ <span class="n">gamma_init</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
239
+ <span class="n">use_W</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
240
+ <span class="p">):</span>
241
+ <span class="nb">super</span><span class="p">(</span><span class="n">ModelInteraction</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
242
+ <span class="bp">self</span><span class="o">.</span><span class="n">use_W</span> <span class="o">=</span> <span class="n">use_W</span>
243
+ <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">LogisticActivation</span><span class="p">(</span><span class="n">x0</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
244
+
245
+ <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">embedding</span>
246
+ <span class="bp">self</span><span class="o">.</span><span class="n">contact</span> <span class="o">=</span> <span class="n">contact</span>
247
+
248
+ <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_W</span><span class="p">:</span>
249
+ <span class="bp">self</span><span class="o">.</span><span class="n">theta</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">([</span><span class="n">theta_init</span><span class="p">]))</span>
250
+ <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">([</span><span class="n">lambda_init</span><span class="p">]))</span>
251
+
252
+ <span class="bp">self</span><span class="o">.</span><span class="n">maxPool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="n">pool_size</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">pool_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
253
+ <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">([</span><span class="n">gamma_init</span><span class="p">]))</span>
254
+
255
+ <span class="bp">self</span><span class="o">.</span><span class="n">clip</span><span class="p">()</span>
256
+
257
+ <span class="k">def</span> <span class="nf">clip</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
258
+ <span class="sd">&quot;&quot;&quot;</span>
259
+ <span class="sd"> Clamp model values</span>
260
+
261
+ <span class="sd"> :meta private:</span>
262
+ <span class="sd"> &quot;&quot;&quot;</span>
263
+ <span class="bp">self</span><span class="o">.</span><span class="n">contact</span><span class="o">.</span><span class="n">clip</span><span class="p">()</span>
264
+
265
+ <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_W</span><span class="p">:</span>
266
+ <span class="bp">self</span><span class="o">.</span><span class="n">theta</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
267
+ <span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
268
+
269
+ <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
270
+
271
+ <div class="viewcode-block" id="ModelInteraction.embed"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.interaction.ModelInteraction.embed">[docs]</a> <span class="k">def</span> <span class="nf">embed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
272
+ <span class="sd">&quot;&quot;&quot;</span>
273
+ <span class="sd"> Project down input language model embeddings into low dimension using projection module</span>
274
+
275
+ <span class="sd"> :param z: Language model embedding :math:`(b \\times N \\times d_0)`</span>
276
+ <span class="sd"> :type z: torch.Tensor</span>
277
+ <span class="sd"> :return: D-SCRIPT projection :math:`(b \\times N \\times d)`</span>
278
+ <span class="sd"> :rtype: torch.Tensor</span>
279
+ <span class="sd"> &quot;&quot;&quot;</span>
280
+ <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
281
+ <span class="k">return</span> <span class="n">z</span>
282
+ <span class="k">else</span><span class="p">:</span>
283
+ <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">z</span><span class="p">)</span></div>
284
+
285
+ <div class="viewcode-block" id="ModelInteraction.cpred"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.interaction.ModelInteraction.cpred">[docs]</a> <span class="k">def</span> <span class="nf">cpred</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">):</span>
286
+ <span class="sd">&quot;&quot;&quot;</span>
287
+ <span class="sd"> Project down input language model embeddings into low dimension using projection module</span>
288
+
289
+ <span class="sd"> :param z0: Language model embedding :math:`(b \\times N \\times d_0)`</span>
290
+ <span class="sd"> :type z0: torch.Tensor</span>
291
+ <span class="sd"> :param z1: Language model embedding :math:`(b \\times N \\times d_0)`</span>
292
+ <span class="sd"> :type z1: torch.Tensor</span>
293
+ <span class="sd"> :return: Predicted contact map :math:`(b \\times N \\times M)`</span>
294
+ <span class="sd"> :rtype: torch.Tensor</span>
295
+ <span class="sd"> &quot;&quot;&quot;</span>
296
+ <span class="n">e0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed</span><span class="p">(</span><span class="n">z0</span><span class="p">)</span>
297
+ <span class="n">e1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed</span><span class="p">(</span><span class="n">z1</span><span class="p">)</span>
298
+ <span class="n">B</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">contact</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">e0</span><span class="p">,</span> <span class="n">e1</span><span class="p">)</span>
299
+ <span class="n">C</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">contact</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
300
+ <span class="k">return</span> <span class="n">C</span></div>
301
+
302
+ <div class="viewcode-block" id="ModelInteraction.map_predict"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.interaction.ModelInteraction.map_predict">[docs]</a> <span class="k">def</span> <span class="nf">map_predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">):</span>
303
+ <span class="sd">&quot;&quot;&quot;</span>
304
+ <span class="sd"> Project down input language model embeddings into low dimension using projection module</span>
305
+
306
+ <span class="sd"> :param z0: Language model embedding :math:`(b \\times N \\times d_0)`</span>
307
+ <span class="sd"> :type z0: torch.Tensor</span>
308
+ <span class="sd"> :param z1: Language model embedding :math:`(b \\times N \\times d_0)`</span>
309
+ <span class="sd"> :type z1: torch.Tensor</span>
310
+ <span class="sd"> :return: Predicted contact map, predicted probability of interaction :math:`(b \\times N \\times d_0), (1)`</span>
311
+ <span class="sd"> :rtype: torch.Tensor, torch.Tensor</span>
312
+ <span class="sd"> &quot;&quot;&quot;</span>
313
+
314
+ <span class="n">C</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cpred</span><span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">)</span>
315
+
316
+ <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_W</span><span class="p">:</span>
317
+ <span class="c1"># Create contact weighting matrix</span>
318
+ <span class="n">N</span><span class="p">,</span> <span class="n">M</span> <span class="o">=</span> <span class="n">C</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span>
319
+
320
+ <span class="n">x1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">N</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">-</span> <span class="p">((</span><span class="n">N</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">))</span> <span class="o">/</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="p">((</span><span class="n">N</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)))</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
321
+ <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="s1">&#39;cuda&#39;</span><span class="p">:</span>
322
+ <span class="n">x1</span> <span class="o">=</span> <span class="n">x1</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
323
+ <span class="n">x1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">*</span> <span class="n">x1</span><span class="p">)</span>
324
+
325
+ <span class="n">x2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">M</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">-</span> <span class="p">((</span><span class="n">M</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">))</span> <span class="o">/</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="p">((</span><span class="n">M</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)))</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
326
+ <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="s1">&#39;cuda&#39;</span><span class="p">:</span>
327
+ <span class="n">x2</span> <span class="o">=</span> <span class="n">x2</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
328
+ <span class="n">x2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lambda_</span> <span class="o">*</span> <span class="n">x2</span><span class="p">)</span>
329
+
330
+ <span class="n">W</span> <span class="o">=</span> <span class="n">x1</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">x2</span>
331
+ <span class="n">W</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">theta</span><span class="p">)</span> <span class="o">*</span> <span class="n">W</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">theta</span>
332
+
333
+ <span class="n">yhat</span> <span class="o">=</span> <span class="n">C</span> <span class="o">*</span> <span class="n">W</span>
334
+
335
+ <span class="k">else</span><span class="p">:</span>
336
+ <span class="n">yhat</span> <span class="o">=</span> <span class="n">C</span>
337
+
338
+ <span class="n">yhat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">maxPool</span><span class="p">(</span><span class="n">yhat</span><span class="p">)</span>
339
+
340
+ <span class="c1"># Mean of contact predictions where p_ij &gt; mu + gamma*sigma</span>
341
+ <span class="n">mu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">yhat</span><span class="p">)</span>
342
+ <span class="n">sigma</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="n">yhat</span><span class="p">)</span>
343
+ <span class="n">Q</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">yhat</span> <span class="o">-</span> <span class="n">mu</span> <span class="o">-</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="n">sigma</span><span class="p">))</span>
344
+ <span class="n">phat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">Q</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sign</span><span class="p">(</span><span class="n">Q</span><span class="p">))</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
345
+ <span class="n">phat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">phat</span><span class="p">)</span>
346
+ <span class="k">return</span> <span class="n">C</span><span class="p">,</span> <span class="n">phat</span></div>
347
+
348
+ <div class="viewcode-block" id="ModelInteraction.predict"><a class="viewcode-back" href="../../../api/dscript.models.html#dscript.models.interaction.ModelInteraction.predict">[docs]</a> <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">):</span>
349
+ <span class="sd">&quot;&quot;&quot;</span>
350
+ <span class="sd"> Project down input language model embeddings into low dimension using projection module</span>
351
+
352
+ <span class="sd"> :param z0: Language model embedding :math:`(b \\times N \\times d_0)`</span>
353
+ <span class="sd"> :type z0: torch.Tensor</span>
354
+ <span class="sd"> :param z1: Language model embedding :math:`(b \\times N \\times d_0)`</span>
355
+ <span class="sd"> :type z1: torch.Tensor</span>
356
+ <span class="sd"> :return: Predicted probability of interaction</span>
357
+ <span class="sd"> :rtype: torch.Tensor, torch.Tensor</span>
358
+ <span class="sd"> &quot;&quot;&quot;</span>
359
+ <span class="n">_</span><span class="p">,</span> <span class="n">phat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">map_predict</span><span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">)</span>
360
+ <span class="k">return</span> <span class="n">phat</span></div>
361
+
362
+
363
+ <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">):</span>
364
+ <span class="sd">&quot;&quot;&quot;</span>
365
+ <span class="sd"> :meta private:</span>
366
+ <span class="sd"> &quot;&quot;&quot;</span>
367
+ <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">z1</span><span class="p">)</span></div>
368
+ </pre></div>
369
+
370
+ </div>
371
+
372
+ </div>
373
+ <footer>
374
+
375
+ <hr/>
376
+
377
+ <div role="contentinfo">
378
+ <p>
379
+ &#169; Copyright 2020, Samuel Sledzieski, Rohit Singh.
380
+
381
+ </p>
382
+ </div>
383
+
384
+
385
+
386
+ Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
387
+
388
+ <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
389
+
390
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
391
+
392
+ </footer>
393
+ </div>
394
+ </div>
395
+
396
+ </section>
397
+
398
+ </div>
399
+
400
+
401
+ <script type="text/javascript">
402
+ jQuery(function () {
403
+ SphinxRtdTheme.Navigation.enable(true);
404
+ });
405
+ </script>
406
+
407
+
408
+
409
+
410
+
411
+
412
+ </body>
413
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/pretrained.html ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+
10
+ <title>dscript.pretrained &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+ <!--[if lt IE 9]>
27
+ <script src="../../_static/js/html5shiv.min.js"></script>
28
+ <![endif]-->
29
+
30
+
31
+ <script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
32
+ <script src="../../_static/jquery.js"></script>
33
+ <script src="../../_static/underscore.js"></script>
34
+ <script src="../../_static/doctools.js"></script>
35
+
36
+ <script type="text/javascript" src="../../_static/js/theme.js"></script>
37
+
38
+
39
+ <link rel="index" title="Index" href="../../genindex.html" />
40
+ <link rel="search" title="Search" href="../../search.html" />
41
+ </head>
42
+
43
+ <body class="wy-body-for-nav">
44
+
45
+
46
+ <div class="wy-grid-for-nav">
47
+
48
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
49
+ <div class="wy-side-scroll">
50
+ <div class="wy-side-nav-search" >
51
+
52
+
53
+
54
+ <a href="../../index.html" class="icon icon-home"> D-SCRIPT
55
+
56
+
57
+
58
+ </a>
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+ <div role="search">
67
+ <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
68
+ <input type="text" name="q" placeholder="Search docs" />
69
+ <input type="hidden" name="check_keywords" value="yes" />
70
+ <input type="hidden" name="area" value="default" />
71
+ </form>
72
+ </div>
73
+
74
+
75
+ </div>
76
+
77
+
78
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
79
+
80
+
81
+
82
+
83
+
84
+
85
+ <ul>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../installation.html">Installation</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../usage.html">Usage</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../data.html">Data</a></li>
89
+ <li class="toctree-l1"><a class="reference internal" href="../../api/index.html">API</a></li>
90
+ </ul>
91
+
92
+
93
+
94
+ </div>
95
+
96
+ </div>
97
+ </nav>
98
+
99
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
100
+
101
+
102
+ <nav class="wy-nav-top" aria-label="top navigation">
103
+
104
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
105
+ <a href="../../index.html">D-SCRIPT</a>
106
+
107
+ </nav>
108
+
109
+
110
+ <div class="wy-nav-content">
111
+
112
+ <div class="rst-content">
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ <div role="navigation" aria-label="breadcrumbs navigation">
133
+
134
+ <ul class="wy-breadcrumbs">
135
+
136
+ <li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
137
+
138
+ <li><a href="../index.html">Module code</a> &raquo;</li>
139
+
140
+ <li>dscript.pretrained</li>
141
+
142
+
143
+ <li class="wy-breadcrumbs-aside">
144
+
145
+ </li>
146
+
147
+ </ul>
148
+
149
+
150
+ <hr/>
151
+ </div>
152
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
153
+ <div itemprop="articleBody">
154
+
155
+ <h1>Source code for dscript.pretrained</h1><div class="highlight"><pre>
156
+ <span></span><span class="kn">import</span> <span class="nn">os</span><span class="o">,</span> <span class="nn">sys</span>
157
+ <span class="kn">import</span> <span class="nn">torch</span>
158
+
159
+ <span class="kn">from</span> <span class="nn">.models.embedding</span> <span class="kn">import</span> <span class="n">FullyConnectedEmbed</span><span class="p">,</span> <span class="n">SkipLSTM</span>
160
+ <span class="kn">from</span> <span class="nn">.models.contact</span> <span class="kn">import</span> <span class="n">ContactCNN</span>
161
+ <span class="kn">from</span> <span class="nn">.models.interaction</span> <span class="kn">import</span> <span class="n">ModelInteraction</span>
162
+
163
+
164
+ <span class="k">def</span> <span class="nf">build_lm_1</span><span class="p">(</span><span class="n">state_dict_path</span><span class="p">):</span>
165
+ <span class="sd">&quot;&quot;&quot;</span>
166
+ <span class="sd"> :meta private:</span>
167
+ <span class="sd"> &quot;&quot;&quot;</span>
168
+ <span class="n">model</span> <span class="o">=</span> <span class="n">SkipLSTM</span><span class="p">(</span><span class="mi">21</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
169
+ <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">state_dict_path</span><span class="p">)</span>
170
+ <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span>
171
+ <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
172
+ <span class="k">return</span> <span class="n">model</span>
173
+
174
+
175
+ <span class="k">def</span> <span class="nf">build_human_1</span><span class="p">(</span><span class="n">state_dict_path</span><span class="p">):</span>
176
+ <span class="sd">&quot;&quot;&quot;</span>
177
+ <span class="sd"> :meta private:</span>
178
+ <span class="sd"> &quot;&quot;&quot;</span>
179
+ <span class="n">embModel</span> <span class="o">=</span> <span class="n">FullyConnectedEmbed</span><span class="p">(</span><span class="mi">6165</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>
180
+ <span class="n">conModel</span> <span class="o">=</span> <span class="n">ContactCNN</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">7</span><span class="p">)</span>
181
+ <span class="n">model</span> <span class="o">=</span> <span class="n">ModelInteraction</span><span class="p">(</span><span class="n">embModel</span><span class="p">,</span> <span class="n">conModel</span><span class="p">,</span> <span class="n">use_W</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">pool_size</span><span class="o">=</span><span class="mi">9</span><span class="p">)</span>
182
+ <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">state_dict_path</span><span class="p">)</span>
183
+ <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">)</span>
184
+ <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
185
+ <span class="k">return</span> <span class="n">model</span>
186
+
187
+
188
+ <span class="n">VALID_MODELS</span> <span class="o">=</span> <span class="p">{</span>
189
+ <span class="s2">&quot;lm_v1&quot;</span><span class="p">:</span> <span class="n">build_lm_1</span><span class="p">,</span>
190
+ <span class="s2">&quot;human_v1&quot;</span><span class="p">:</span> <span class="n">build_human_1</span>
191
+ <span class="p">}</span>
192
+
193
+
194
+ <div class="viewcode-block" id="get_state_dict"><a class="viewcode-back" href="../../api/index.html#dscript.pretrained.get_state_dict">[docs]</a><span class="k">def</span> <span class="nf">get_state_dict</span><span class="p">(</span><span class="n">version</span><span class="o">=</span><span class="s2">&quot;human_v1&quot;</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
195
+ <span class="sd">&quot;&quot;&quot;</span>
196
+ <span class="sd"> Download a pre-trained model if not already exists on local device.</span>
197
+
198
+ <span class="sd"> :param version: Version of trained model to download [default: human_1]</span>
199
+ <span class="sd"> :type version: str</span>
200
+ <span class="sd"> :param verbose: Print model download status on stdout [default: True]</span>
201
+ <span class="sd"> :type verbose: bool</span>
202
+ <span class="sd"> :return: Path to state dictionary for pre-trained language model</span>
203
+ <span class="sd"> :rtype: str</span>
204
+ <span class="sd"> &quot;&quot;&quot;</span>
205
+ <span class="n">state_dict_basename</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;dscript_</span><span class="si">{</span><span class="n">version</span><span class="si">}</span><span class="s2">.pt&quot;</span>
206
+ <span class="n">state_dict_basedir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">realpath</span><span class="p">(</span><span class="vm">__file__</span><span class="p">))</span>
207
+ <span class="n">state_dict_fullname</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">state_dict_basedir</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span class="n">state_dict_basename</span><span class="si">}</span><span class="s2">&quot;</span>
208
+ <span class="n">state_dict_url</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;http://cb.csail.mit.edu/cb/dscript/data/models/</span><span class="si">{</span><span class="n">state_dict_basename</span><span class="si">}</span><span class="s2">&quot;</span>
209
+ <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">state_dict_fullname</span><span class="p">):</span>
210
+ <span class="k">try</span><span class="p">:</span>
211
+ <span class="kn">import</span> <span class="nn">urllib.request</span>
212
+ <span class="kn">import</span> <span class="nn">shutil</span>
213
+ <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Downloading model </span><span class="si">{</span><span class="n">version</span><span class="si">}</span><span class="s2"> from </span><span class="si">{</span><span class="n">state_dict_url</span><span class="si">}</span><span class="s2">...&quot;</span><span class="p">)</span>
214
+ <span class="k">with</span> <span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlopen</span><span class="p">(</span><span class="n">state_dict_url</span><span class="p">)</span> <span class="k">as</span> <span class="n">response</span><span class="p">,</span> <span class="nb">open</span><span class="p">(</span><span class="n">state_dict_fullname</span><span class="p">,</span> <span class="s1">&#39;wb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">out_file</span><span class="p">:</span>
215
+ <span class="n">shutil</span><span class="o">.</span><span class="n">copyfileobj</span><span class="p">(</span><span class="n">response</span><span class="p">,</span> <span class="n">out_file</span><span class="p">)</span>
216
+ <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
217
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Unable to download model - </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">e</span><span class="p">))</span>
218
+ <span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
219
+ <span class="k">return</span> <span class="n">state_dict_fullname</span></div>
220
+
221
+
222
+ <div class="viewcode-block" id="get_pretrained"><a class="viewcode-back" href="../../api/index.html#dscript.pretrained.get_pretrained">[docs]</a><span class="k">def</span> <span class="nf">get_pretrained</span><span class="p">(</span><span class="n">version</span><span class="o">=</span><span class="s2">&quot;human_v1&quot;</span><span class="p">):</span>
223
+ <span class="sd">&quot;&quot;&quot;</span>
224
+ <span class="sd"> Get pre-trained model object.</span>
225
+
226
+ <span class="sd"> Currently Available Models</span>
227
+ <span class="sd"> ==========================</span>
228
+
229
+ <span class="sd"> See the `documentation &lt;https://d-script.readthedocs.io/en/main/data.html#trained-models&gt;`_ for most up-to-date list.</span>
230
+
231
+ <span class="sd"> - ``lm_v1`` - Language model from `Bepler &amp; Berger &lt;https://github.com/tbepler/protein-sequence-embedding-iclr2019&gt;`_.</span>
232
+ <span class="sd"> - ``human_v1`` - Human trained model from D-SCRIPT manuscript.</span>
233
+
234
+ <span class="sd"> Default: ``human_v1``</span>
235
+
236
+ <span class="sd"> :param version: Version of pre-trained model to get</span>
237
+ <span class="sd"> :type version: str</span>
238
+ <span class="sd"> :return: Pre-trained model</span>
239
+ <span class="sd"> :rtype: dscript.models.*</span>
240
+ <span class="sd"> &quot;&quot;&quot;</span>
241
+ <span class="k">if</span> <span class="ow">not</span> <span class="n">version</span> <span class="ow">in</span> <span class="n">VALID_MODELS</span><span class="p">:</span>
242
+ <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Model </span><span class="si">{}</span><span class="s2"> does not exist&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">version</span><span class="p">))</span>
243
+
244
+ <span class="n">state_dict_path</span> <span class="o">=</span> <span class="n">get_state_dict</span><span class="p">(</span><span class="n">version</span><span class="p">)</span>
245
+ <span class="k">return</span> <span class="n">VALID_MODELS</span><span class="p">[</span><span class="n">version</span><span class="p">](</span><span class="n">state_dict_path</span><span class="p">)</span></div>
246
+ </pre></div>
247
+
248
+ </div>
249
+
250
+ </div>
251
+ <footer>
252
+
253
+ <hr/>
254
+
255
+ <div role="contentinfo">
256
+ <p>
257
+ &#169; Copyright 2020, Samuel Sledzieski, Rohit Singh.
258
+
259
+ </p>
260
+ </div>
261
+
262
+
263
+
264
+ Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
265
+
266
+ <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
267
+
268
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
269
+
270
+ </footer>
271
+ </div>
272
+ </div>
273
+
274
+ </section>
275
+
276
+ </div>
277
+
278
+
279
+ <script type="text/javascript">
280
+ jQuery(function () {
281
+ SphinxRtdTheme.Navigation.enable(true);
282
+ });
283
+ </script>
284
+
285
+
286
+
287
+
288
+
289
+
290
+ </body>
291
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/dscript/utils.html ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+
10
+ <title>dscript.utils &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+ <!--[if lt IE 9]>
27
+ <script src="../../_static/js/html5shiv.min.js"></script>
28
+ <![endif]-->
29
+
30
+
31
+ <script type="text/javascript" id="documentation_options" data-url_root="../../" src="../../_static/documentation_options.js"></script>
32
+ <script src="../../_static/jquery.js"></script>
33
+ <script src="../../_static/underscore.js"></script>
34
+ <script src="../../_static/doctools.js"></script>
35
+
36
+ <script type="text/javascript" src="../../_static/js/theme.js"></script>
37
+
38
+
39
+ <link rel="index" title="Index" href="../../genindex.html" />
40
+ <link rel="search" title="Search" href="../../search.html" />
41
+ </head>
42
+
43
+ <body class="wy-body-for-nav">
44
+
45
+
46
+ <div class="wy-grid-for-nav">
47
+
48
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
49
+ <div class="wy-side-scroll">
50
+ <div class="wy-side-nav-search" >
51
+
52
+
53
+
54
+ <a href="../../index.html" class="icon icon-home"> D-SCRIPT
55
+
56
+
57
+
58
+ </a>
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+ <div role="search">
67
+ <form id="rtd-search-form" class="wy-form" action="../../search.html" method="get">
68
+ <input type="text" name="q" placeholder="Search docs" />
69
+ <input type="hidden" name="check_keywords" value="yes" />
70
+ <input type="hidden" name="area" value="default" />
71
+ </form>
72
+ </div>
73
+
74
+
75
+ </div>
76
+
77
+
78
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
79
+
80
+
81
+
82
+
83
+
84
+
85
+ <ul>
86
+ <li class="toctree-l1"><a class="reference internal" href="../../installation.html">Installation</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../../usage.html">Usage</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../../data.html">Data</a></li>
89
+ <li class="toctree-l1"><a class="reference internal" href="../../api/index.html">API</a></li>
90
+ </ul>
91
+
92
+
93
+
94
+ </div>
95
+
96
+ </div>
97
+ </nav>
98
+
99
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
100
+
101
+
102
+ <nav class="wy-nav-top" aria-label="top navigation">
103
+
104
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
105
+ <a href="../../index.html">D-SCRIPT</a>
106
+
107
+ </nav>
108
+
109
+
110
+ <div class="wy-nav-content">
111
+
112
+ <div class="rst-content">
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ <div role="navigation" aria-label="breadcrumbs navigation">
133
+
134
+ <ul class="wy-breadcrumbs">
135
+
136
+ <li><a href="../../index.html" class="icon icon-home"></a> &raquo;</li>
137
+
138
+ <li><a href="../index.html">Module code</a> &raquo;</li>
139
+
140
+ <li>dscript.utils</li>
141
+
142
+
143
+ <li class="wy-breadcrumbs-aside">
144
+
145
+ </li>
146
+
147
+ </ul>
148
+
149
+
150
+ <hr/>
151
+ </div>
152
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
153
+ <div itemprop="articleBody">
154
+
155
+ <h1>Source code for dscript.utils</h1><div class="highlight"><pre>
156
+ <span></span><span class="kn">import</span> <span class="nn">torch</span>
157
+ <span class="kn">import</span> <span class="nn">torch.utils.data</span>
158
+
159
+ <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
160
+ <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
161
+ <span class="kn">import</span> <span class="nn">subprocess</span> <span class="k">as</span> <span class="nn">sp</span>
162
+ <span class="kn">import</span> <span class="nn">sys</span>
163
+ <span class="kn">import</span> <span class="nn">gzip</span> <span class="k">as</span> <span class="nn">gz</span>
164
+ <span class="kn">from</span> <span class="nn">datetime</span> <span class="kn">import</span> <span class="n">datetime</span>
165
+ <span class="kn">from</span> <span class="nn">.fasta</span> <span class="kn">import</span> <span class="n">parse</span>
166
+
167
+ <div class="viewcode-block" id="log"><a class="viewcode-back" href="../../api/index.html#dscript.utils.log">[docs]</a><span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="p">):</span>
168
+ <span class="sd">&quot;&quot;&quot;</span>
169
+ <span class="sd"> Log datetime-stamped message to file</span>
170
+ <span class="sd"> </span>
171
+ <span class="sd"> :param msg: Message to log</span>
172
+ <span class="sd"> :param f: Writable file object to log message to</span>
173
+ <span class="sd"> &quot;&quot;&quot;</span>
174
+ <span class="n">timestr</span> <span class="o">=</span> <span class="n">datetime</span><span class="o">.</span><span class="n">utcnow</span><span class="p">()</span><span class="o">.</span><span class="n">isoformat</span><span class="p">(</span><span class="n">sep</span><span class="o">=</span><span class="s1">&#39;-&#39;</span><span class="p">,</span> <span class="n">timespec</span><span class="o">=</span><span class="s1">&#39;milliseconds&#39;</span><span class="p">)</span>
175
+ <span class="n">file</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;[</span><span class="si">{</span><span class="n">timestr</span><span class="si">}</span><span class="s2">] </span><span class="si">{</span><span class="n">msg</span><span class="si">}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
176
+ <span class="n">file</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span></div>
177
+
178
+ <div class="viewcode-block" id="plot_PR_curve"><a class="viewcode-back" href="../../api/index.html#dscript.utils.plot_PR_curve">[docs]</a><span class="k">def</span> <span class="nf">plot_PR_curve</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">phat</span><span class="p">,</span> <span class="n">saveFile</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
179
+ <span class="sd">&quot;&quot;&quot;</span>
180
+ <span class="sd"> Plot precision-recall curve.</span>
181
+
182
+ <span class="sd"> :param y: Labels</span>
183
+ <span class="sd"> :type y: np.ndarray</span>
184
+ <span class="sd"> :param phat: Predicted probabilities</span>
185
+ <span class="sd"> :type phat: np.ndarray</span>
186
+ <span class="sd"> :param saveFile: File for plot of curve to be saved to</span>
187
+ <span class="sd"> :type saveFile: str</span>
188
+ <span class="sd"> &quot;&quot;&quot;</span>
189
+ <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
190
+ <span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">precision_recall_curve</span><span class="p">,</span> <span class="n">average_precision_score</span>
191
+
192
+ <span class="n">aupr</span> <span class="o">=</span> <span class="n">average_precision_score</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">phat</span><span class="p">)</span>
193
+ <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">precision_recall_curve</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">phat</span><span class="p">)</span>
194
+
195
+ <span class="n">plt</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">recall</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">)</span>
196
+ <span class="n">plt</span><span class="o">.</span><span class="n">fill_between</span><span class="p">(</span><span class="n">recall</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">)</span>
197
+ <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">&quot;Recall&quot;</span><span class="p">)</span>
198
+ <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">&quot;Precision&quot;</span><span class="p">)</span>
199
+ <span class="n">plt</span><span class="o">.</span><span class="n">ylim</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.05</span><span class="p">])</span>
200
+ <span class="n">plt</span><span class="o">.</span><span class="n">xlim</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">])</span>
201
+ <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Precision-Recall (AUPR: </span><span class="si">{:.3}</span><span class="s2">)&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">aupr</span><span class="p">))</span>
202
+ <span class="k">if</span> <span class="n">saveFile</span><span class="p">:</span>
203
+ <span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="n">saveFile</span><span class="p">)</span>
204
+ <span class="k">else</span><span class="p">:</span>
205
+ <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></div>
206
+
207
+
208
+ <div class="viewcode-block" id="plot_ROC_curve"><a class="viewcode-back" href="../../api/index.html#dscript.utils.plot_ROC_curve">[docs]</a><span class="k">def</span> <span class="nf">plot_ROC_curve</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">phat</span><span class="p">,</span> <span class="n">saveFile</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
209
+ <span class="sd">&quot;&quot;&quot;</span>
210
+ <span class="sd"> Plot receiver operating characteristic curve.</span>
211
+
212
+ <span class="sd"> :param y: Labels</span>
213
+ <span class="sd"> :type y: np.ndarray</span>
214
+ <span class="sd"> :param phat: Predicted probabilities</span>
215
+ <span class="sd"> :type phat: np.ndarray</span>
216
+ <span class="sd"> :param saveFile: File for plot of curve to be saved to</span>
217
+ <span class="sd"> :type saveFile: str</span>
218
+ <span class="sd"> &quot;&quot;&quot;</span>
219
+ <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
220
+ <span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">roc_curve</span><span class="p">,</span> <span class="n">roc_auc_score</span>
221
+
222
+ <span class="n">auroc</span> <span class="o">=</span> <span class="n">roc_auc_score</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">phat</span><span class="p">)</span>
223
+
224
+ <span class="n">fpr</span><span class="p">,</span> <span class="n">tpr</span><span class="p">,</span> <span class="n">roc_thresh</span> <span class="o">=</span> <span class="n">roc_curve</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">phat</span><span class="p">)</span>
225
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;AUROC:&quot;</span><span class="p">,</span> <span class="n">auroc</span><span class="p">)</span>
226
+
227
+ <span class="n">plt</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">fpr</span><span class="p">,</span> <span class="n">tpr</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">)</span>
228
+ <span class="n">plt</span><span class="o">.</span><span class="n">fill_between</span><span class="p">(</span><span class="n">fpr</span><span class="p">,</span> <span class="n">tpr</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s2">&quot;b&quot;</span><span class="p">)</span>
229
+ <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">&quot;FPR&quot;</span><span class="p">)</span>
230
+ <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">&quot;TPR&quot;</span><span class="p">)</span>
231
+ <span class="n">plt</span><span class="o">.</span><span class="n">ylim</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.05</span><span class="p">])</span>
232
+ <span class="n">plt</span><span class="o">.</span><span class="n">xlim</span><span class="p">([</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">])</span>
233
+ <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Receiver Operating Characteristic (AUROC: </span><span class="si">{:.3}</span><span class="s2">)&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">auroc</span><span class="p">))</span>
234
+ <span class="k">if</span> <span class="n">saveFile</span><span class="p">:</span>
235
+ <span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="n">saveFile</span><span class="p">)</span>
236
+ <span class="k">else</span><span class="p">:</span>
237
+ <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></div>
238
+
239
+
240
+ <div class="viewcode-block" id="RBF"><a class="viewcode-back" href="../../api/index.html#dscript.utils.RBF">[docs]</a><span class="k">def</span> <span class="nf">RBF</span><span class="p">(</span><span class="n">D</span><span class="p">,</span> <span class="n">sigma</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
241
+ <span class="sd">&quot;&quot;&quot;</span>
242
+ <span class="sd"> Convert distance matrix into similarity matrix using Radial Basis Function (RBF) Kernel.</span>
243
+
244
+ <span class="sd"> :math:`RBF(x,x&#39;) = \\exp{\\frac{-(x - x&#39;)^{2}}{2\\sigma^{2}}}`</span>
245
+
246
+ <span class="sd"> :param D: Distance matrix</span>
247
+ <span class="sd"> :type D: np.ndarray</span>
248
+ <span class="sd"> :param sigma: Bandwith of RBF Kernel [default: :math:`\\sqrt{\\text{max}(D)}`]</span>
249
+ <span class="sd"> :type sigma: float</span>
250
+ <span class="sd"> :return: Similarity matrix</span>
251
+ <span class="sd"> :rtype: np.ndarray</span>
252
+ <span class="sd"> &quot;&quot;&quot;</span>
253
+ <span class="n">sigma</span> <span class="o">=</span> <span class="n">sigma</span> <span class="ow">or</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">D</span><span class="p">))</span>
254
+ <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">D</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)))</span></div>
255
+
256
+
257
+ <div class="viewcode-block" id="gpu_mem"><a class="viewcode-back" href="../../api/index.html#dscript.utils.gpu_mem">[docs]</a><span class="k">def</span> <span class="nf">gpu_mem</span><span class="p">(</span><span class="n">device</span><span class="p">):</span>
258
+ <span class="sd">&quot;&quot;&quot;</span>
259
+ <span class="sd"> Get current memory usage for GPU.</span>
260
+
261
+ <span class="sd"> :param device: GPU device number</span>
262
+ <span class="sd"> :type device: int</span>
263
+ <span class="sd"> :return: memory used, memory total</span>
264
+ <span class="sd"> :rtype: int, int</span>
265
+ <span class="sd"> &quot;&quot;&quot;</span>
266
+ <span class="n">result</span> <span class="o">=</span> <span class="n">sp</span><span class="o">.</span><span class="n">check_output</span><span class="p">(</span>
267
+ <span class="p">[</span>
268
+ <span class="s2">&quot;nvidia-smi&quot;</span><span class="p">,</span>
269
+ <span class="s2">&quot;--query-gpu=memory.used,memory.total&quot;</span><span class="p">,</span>
270
+ <span class="s2">&quot;--format=csv,nounits,noheader&quot;</span><span class="p">,</span>
271
+ <span class="s2">&quot;--id=</span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
272
+ <span class="p">],</span>
273
+ <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">,</span>
274
+ <span class="p">)</span>
275
+ <span class="n">gpu_memory</span> <span class="o">=</span> <span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">result</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;,&quot;</span><span class="p">)]</span>
276
+ <span class="k">return</span> <span class="n">gpu_memory</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">gpu_memory</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></div>
277
+
278
+
279
+ <div class="viewcode-block" id="PairedDataset"><a class="viewcode-back" href="../../api/index.html#dscript.utils.PairedDataset">[docs]</a><span class="k">class</span> <span class="nc">PairedDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">):</span>
280
+ <span class="sd">&quot;&quot;&quot;</span>
281
+ <span class="sd"> Dataset to be used by the PyTorch data loader for pairs of sequences and their labels.</span>
282
+
283
+ <span class="sd"> :param X0: List of first item in the pair</span>
284
+ <span class="sd"> :param X1: List of second item in the pair</span>
285
+ <span class="sd"> :param Y: List of labels</span>
286
+ <span class="sd"> &quot;&quot;&quot;</span>
287
+ <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X0</span><span class="p">,</span> <span class="n">X1</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
288
+ <span class="bp">self</span><span class="o">.</span><span class="n">X0</span> <span class="o">=</span> <span class="n">X0</span>
289
+ <span class="bp">self</span><span class="o">.</span><span class="n">X1</span> <span class="o">=</span> <span class="n">X1</span>
290
+ <span class="bp">self</span><span class="o">.</span><span class="n">Y</span> <span class="o">=</span> <span class="n">Y</span>
291
+ <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">X0</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">X1</span><span class="p">),</span> <span class="s2">&quot;X0: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">X0</span><span class="p">))</span> <span class="o">+</span> <span class="s2">&quot; X1: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">X1</span><span class="p">))</span> <span class="o">+</span> <span class="s2">&quot; Y: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">Y</span><span class="p">))</span>
292
+ <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">X0</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">Y</span><span class="p">),</span> <span class="s2">&quot;X0: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">X0</span><span class="p">))</span> <span class="o">+</span> <span class="s2">&quot; X1: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">X1</span><span class="p">))</span> <span class="o">+</span> <span class="s2">&quot; Y: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">Y</span><span class="p">))</span>
293
+
294
+ <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
295
+ <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X0</span><span class="p">)</span>
296
+
297
+ <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">i</span><span class="p">):</span>
298
+ <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">X0</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">X1</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></div>
299
+
300
+
301
+ <div class="viewcode-block" id="collate_paired_sequences"><a class="viewcode-back" href="../../api/index.html#dscript.utils.collate_paired_sequences">[docs]</a><span class="k">def</span> <span class="nf">collate_paired_sequences</span><span class="p">(</span><span class="n">args</span><span class="p">):</span>
302
+ <span class="sd">&quot;&quot;&quot;</span>
303
+ <span class="sd"> Collate function for PyTorch data loader.</span>
304
+ <span class="sd"> &quot;&quot;&quot;</span>
305
+ <span class="n">x0</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">args</span><span class="p">]</span>
306
+ <span class="n">x1</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">args</span><span class="p">]</span>
307
+ <span class="n">y</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">args</span><span class="p">]</span>
308
+ <span class="k">return</span> <span class="n">x0</span><span class="p">,</span> <span class="n">x1</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></div>
309
+ </pre></div>
310
+
311
+ </div>
312
+
313
+ </div>
314
+ <footer>
315
+
316
+ <hr/>
317
+
318
+ <div role="contentinfo">
319
+ <p>
320
+ &#169; Copyright 2020, Samuel Sledzieski, Rohit Singh.
321
+
322
+ </p>
323
+ </div>
324
+
325
+
326
+
327
+ Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
328
+
329
+ <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
330
+
331
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
332
+
333
+ </footer>
334
+ </div>
335
+ </div>
336
+
337
+ </section>
338
+
339
+ </div>
340
+
341
+
342
+ <script type="text/javascript">
343
+ jQuery(function () {
344
+ SphinxRtdTheme.Navigation.enable(true);
345
+ });
346
+ </script>
347
+
348
+
349
+
350
+
351
+
352
+
353
+ </body>
354
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_modules/index.html ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ <!DOCTYPE html>
4
+ <html class="writer-html5" lang="en" >
5
+ <head>
6
+ <meta charset="utf-8" />
7
+
8
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+
10
+ <title>Overview: module code &mdash; D-SCRIPT v1.0-beta documentation</title>
11
+
12
+
13
+
14
+ <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
15
+ <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+ <!--[if lt IE 9]>
27
+ <script src="../_static/js/html5shiv.min.js"></script>
28
+ <![endif]-->
29
+
30
+
31
+ <script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
32
+ <script src="../_static/jquery.js"></script>
33
+ <script src="../_static/underscore.js"></script>
34
+ <script src="../_static/doctools.js"></script>
35
+
36
+ <script type="text/javascript" src="../_static/js/theme.js"></script>
37
+
38
+
39
+ <link rel="index" title="Index" href="../genindex.html" />
40
+ <link rel="search" title="Search" href="../search.html" />
41
+ </head>
42
+
43
+ <body class="wy-body-for-nav">
44
+
45
+
46
+ <div class="wy-grid-for-nav">
47
+
48
+ <nav data-toggle="wy-nav-shift" class="wy-nav-side">
49
+ <div class="wy-side-scroll">
50
+ <div class="wy-side-nav-search" >
51
+
52
+
53
+
54
+ <a href="../index.html" class="icon icon-home"> D-SCRIPT
55
+
56
+
57
+
58
+ </a>
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+ <div role="search">
67
+ <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
68
+ <input type="text" name="q" placeholder="Search docs" />
69
+ <input type="hidden" name="check_keywords" value="yes" />
70
+ <input type="hidden" name="area" value="default" />
71
+ </form>
72
+ </div>
73
+
74
+
75
+ </div>
76
+
77
+
78
+ <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
79
+
80
+
81
+
82
+
83
+
84
+
85
+ <ul>
86
+ <li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
87
+ <li class="toctree-l1"><a class="reference internal" href="../usage.html">Usage</a></li>
88
+ <li class="toctree-l1"><a class="reference internal" href="../data.html">Data</a></li>
89
+ <li class="toctree-l1"><a class="reference internal" href="../api/index.html">API</a></li>
90
+ </ul>
91
+
92
+
93
+
94
+ </div>
95
+
96
+ </div>
97
+ </nav>
98
+
99
+ <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
100
+
101
+
102
+ <nav class="wy-nav-top" aria-label="top navigation">
103
+
104
+ <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
105
+ <a href="../index.html">D-SCRIPT</a>
106
+
107
+ </nav>
108
+
109
+
110
+ <div class="wy-nav-content">
111
+
112
+ <div class="rst-content">
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ <div role="navigation" aria-label="breadcrumbs navigation">
133
+
134
+ <ul class="wy-breadcrumbs">
135
+
136
+ <li><a href="../index.html" class="icon icon-home"></a> &raquo;</li>
137
+
138
+ <li>Overview: module code</li>
139
+
140
+
141
+ <li class="wy-breadcrumbs-aside">
142
+
143
+ </li>
144
+
145
+ </ul>
146
+
147
+
148
+ <hr/>
149
+ </div>
150
+ <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
151
+ <div itemprop="articleBody">
152
+
153
+ <h1>All modules for which code is available</h1>
154
+ <ul><li><a href="dscript/alphabets.html">dscript.alphabets</a></li>
155
+ <li><a href="dscript/commands/eval.html">dscript.commands.eval</a></li>
156
+ <li><a href="dscript/commands/train.html">dscript.commands.train</a></li>
157
+ <li><a href="dscript/fasta.html">dscript.fasta</a></li>
158
+ <li><a href="dscript/language_model.html">dscript.language_model</a></li>
159
+ <li><a href="dscript/models/contact.html">dscript.models.contact</a></li>
160
+ <li><a href="dscript/models/embedding.html">dscript.models.embedding</a></li>
161
+ <li><a href="dscript/models/interaction.html">dscript.models.interaction</a></li>
162
+ <li><a href="dscript/pretrained.html">dscript.pretrained</a></li>
163
+ <li><a href="dscript/utils.html">dscript.utils</a></li>
164
+ </ul>
165
+
166
+ </div>
167
+
168
+ </div>
169
+ <footer>
170
+
171
+ <hr/>
172
+
173
+ <div role="contentinfo">
174
+ <p>
175
+ &#169; Copyright 2020, Samuel Sledzieski, Rohit Singh.
176
+
177
+ </p>
178
+ </div>
179
+
180
+
181
+
182
+ Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
183
+
184
+ <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
185
+
186
+ provided by <a href="https://readthedocs.org">Read the Docs</a>.
187
+
188
+ </footer>
189
+ </div>
190
+ </div>
191
+
192
+ </section>
193
+
194
+ </div>
195
+
196
+
197
+ <script type="text/javascript">
198
+ jQuery(function () {
199
+ SphinxRtdTheme.Navigation.enable(true);
200
+ });
201
+ </script>
202
+
203
+
204
+
205
+
206
+
207
+
208
+ </body>
209
+ </html>
samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/api/dscript.commands.rst.txt ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dscript.commands
2
+ ================
3
+
4
+ dscript.commands.predict
5
+ ------------------------
6
+
7
+ See `Prediction <../usage.html#prediction>`_ for full usage details.
8
+
9
+ .. automodule:: dscript.commands.predict
10
+ :members:
11
+ :undoc-members:
12
+ :show-inheritance:
13
+
14
+ dscript.commands.embed
15
+ ----------------------
16
+
17
+ See `Embedding <../usage.html#embedding>`_ for full usage details.
18
+
19
+ .. automodule:: dscript.commands.embed
20
+ :members:
21
+ :undoc-members:
22
+ :show-inheritance:
23
+
24
+ dscript.commands.train
25
+ ----------------------
26
+
27
+ See `Training <../usage.html#training>`_ for full usage details.
28
+
29
+ .. automodule:: dscript.commands.train
30
+ :members:
31
+ :undoc-members:
32
+ :show-inheritance:
33
+
34
+ dscript.commands.eval
35
+ ---------------------
36
+
37
+ See `Evaluation <../usage.html#evaluation>`_ for full usage details.
38
+
39
+ .. automodule:: dscript.commands.eval
40
+ :members:
41
+ :undoc-members:
42
+ :show-inheritance:
samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/api/dscript.models.rst.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dscript.models
2
+ ==============
3
+
4
+ dscript.models.embedding
5
+ ------------------------
6
+
7
+ .. automodule:: dscript.models.embedding
8
+ :members:
9
+ :undoc-members:
10
+ :show-inheritance:
11
+
12
+ dscript.models.contact
13
+ ----------------------
14
+
15
+ .. automodule:: dscript.models.contact
16
+ :members:
17
+ :undoc-members:
18
+ :show-inheritance:
19
+
20
+ dscript.models.interaction
21
+ --------------------------
22
+
23
+ .. automodule:: dscript.models.interaction
24
+ :members:
25
+ :undoc-members:
26
+ :show-inheritance:
samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/api/index.rst.txt ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ API
2
+ ===
3
+
4
+ .. toctree::
5
+ :maxdepth: 4
6
+
7
+ dscript.commands
8
+ dscript.models
9
+
10
+ dscript.alphabets
11
+ -----------------
12
+
13
+ .. automodule:: dscript.alphabets
14
+ :members:
15
+ :undoc-members:
16
+ :show-inheritance:
17
+
18
+ dscript.fasta
19
+ -------------
20
+
21
+ .. automodule:: dscript.fasta
22
+ :members:
23
+ :undoc-members:
24
+ :show-inheritance:
25
+
26
+ dscript.language\_model
27
+ -----------------------
28
+
29
+ .. automodule:: dscript.language_model
30
+ :members:
31
+ :undoc-members:
32
+ :show-inheritance:
33
+
34
+ dscript.pretrained
35
+ ------------------
36
+
37
+ .. automodule:: dscript.pretrained
38
+ :members:
39
+ :undoc-members:
40
+ :show-inheritance:
41
+
42
+ dscript.utils
43
+ -------------
44
+
45
+ .. automodule:: dscript.utils
46
+ :members:
47
+ :undoc-members:
48
+ :show-inheritance:
samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/data.rst.txt ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Data
2
+ ====
3
+
4
+ Trained Models
5
+ --------------
6
+
7
+ - `Bepler & Berger language model <http://cb.csail.mit.edu/cb/dscript/data/models/lm_v1.sav>`_
8
+ - `Human data trained model <http://cb.csail.mit.edu/cb/dscript/data/models/human_v1.sav>`_
9
+
10
+ Sample Data
11
+ -----------
12
+
13
+ Sequences
14
+ ~~~~~~~~~
15
+
16
+ - `Human`_
17
+ - `Mouse`_
18
+ - `Fly`_
19
+ - `Yeast`_
20
+ - `Worm`_
21
+
22
+ Interactions
23
+ ~~~~~~~~~~~~
24
+
25
+ - `Human Train`_
26
+ - `Human Test`_
27
+ - `Mouse Test`_
28
+ - `Fly Test`_
29
+ - `Yeast Test`_
30
+ - `Worm Test`_
31
+
32
+ .. _`Human`: https://github.com/samsledje/D-SCRIPT/blob/main/data/seqs/human.fasta
33
+ .. _`Mouse`: https://github.com/samsledje/D-SCRIPT/blob/main/data/seqs/mouse.fasta
34
+ .. _`Fly`: https://github.com/samsledje/D-SCRIPT/blob/main/data/seqs/fly.fasta
35
+ .. _`Yeast`: https://github.com/samsledje/D-SCRIPT/blob/main/data/seqs/yeast.fasta
36
+ .. _`Worm`: https://github.com/samsledje/D-SCRIPT/blob/main/data/seqs/worm.fasta
37
+ .. _`Human Train`: https://github.com/samsledje/D-SCRIPT/blob/main/data/pairs/human_train.tsv
38
+ .. _`Human Test`: https://github.com/samsledje/D-SCRIPT/blob/main/data/pairs/human_test.tsv
39
+ .. _`Mouse Test`: https://github.com/samsledje/D-SCRIPT/blob/main/data/pairs/mouse_test.tsv
40
+ .. _`Fly Test`: https://github.com/samsledje/D-SCRIPT/blob/main/data/pairs/fly_test.tsv
41
+ .. _`Yeast Test`: https://github.com/samsledje/D-SCRIPT/blob/main/data/pairs/yeast_test.tsv
42
+ .. _`Worm Test`: https://github.com/samsledje/D-SCRIPT/blob/main/data/pairs/worm_test.tsv
43
+
samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/index.rst.txt ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ D-SCRIPT: Deep Learning PPI Prediction
2
+ =======================================
3
+
4
+ - `D-SCRIPT Home Page`_
5
+
6
+ - `Quick Start <usage.html#quick-start>`_
7
+
8
+ D-SCRIPT is a deep learning method for predicting a physical interaction between two proteins given just their sequences.
9
+ It generalizes well to new species and is robust to limitations in training data size.
10
+ Its design reflects the intuition that for two proteins to physically interact, a subset of amino acids from each protein should be in contact with the other.
11
+ The intermediate stages of D-SCRIPT directly implement this intuition, with the penultimate stage in D-SCRIPT being a rough estimate of the inter-protein
12
+ contact map of the protein dimer. This structurally-motivated design enhances the interpretability of the results and, since structure is more conserved
13
+ evolutionarily than sequence, improves generalizability across species.
14
+
15
+ If you use D-SCRIPT, please cite "Sequence-based prediction of protein-protein interactions: a structure-aware interpetable deep learning model"
16
+ by `Sam Sledzieski`_, `Rohit Singh`_, `Lenore Cowen`_, and `Bonnie Berger`_ [link TBD].
17
+
18
+ .. _`D-SCRIPT Home Page`: http://dscript.csail.mit.edu
19
+ .. _`Sam Sledzieski`: http://samsledje.github.io/
20
+ .. _`Rohit Singh`: http://people.csail.mit.edu/rsingh/
21
+ .. _`Lenore Cowen`: http://www.cs.tufts.edu/~cowen/
22
+ .. _`Bonnie Berger`: http://people.csail.mit.edu/bab/
23
+
24
+ Table of contents
25
+ =================
26
+
27
+ .. toctree::
28
+ :maxdepth: 1
29
+
30
+ installation
31
+ usage
32
+ data
33
+ api/index
34
+
35
+ Indices and tables
36
+ ==================
37
+
38
+ * :ref:`genindex`
39
+ * :ref:`modindex`
samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/installation.rst.txt ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Installation
2
+ ============
3
+
4
+ Requirements
5
+ ------------
6
+ - python 3.7
7
+ - pytorch 1.5
8
+ - h5py
9
+ - matplotlib
10
+ - numpy
11
+ - pandas
12
+ - scikit-learn
13
+ - scipy
14
+ - seaborn
15
+ - setuptools
16
+ - tqdm
17
+
18
+ Optional GPU support: CUDA Toolkit, cuDNN
19
+
20
+ Set up environment
21
+ ------------------
22
+
23
+ .. code-block:: bash
24
+
25
+ $ git clone https://github.com/samsledje/D-SCRIPT.git
26
+
27
+ $ cd D-SCRIPT
28
+
29
+ $ conda env create --file environment.yml # Edit this file to change CUDA version if necessary
30
+
31
+ $ conda activate dscript
32
+
33
+ Install from pip
34
+ ----------------
35
+
36
+ .. code-block:: bash
37
+
38
+ pip install dscript
39
+
40
+ Build from source
41
+ -----------------
42
+
43
+ .. code-block:: bash
44
+
45
+ $ git clone https://github.com/samsledje/D-SCRIPT.git
46
+
47
+ $ cd D-SCRIPT
48
+
49
+ $ python setup.py build; python setup.py install
samsledje-D-SCRIPT-8a55490/docs/build/html/_sources/usage.rst.txt ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Usage
2
+ =====
3
+
4
+ Quick Start
5
+ ~~~~~~~~~~~
6
+
7
+ Predict a new network using a trained model
8
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
9
+
10
+ Pre-trained models can be downloaded from [TBD].
11
+ Candidate pairs should be in tab-separated (``.tsv``) format with no header, and columns for [protein name 1], [protein name 2].
12
+ Optionally, a third column with [label] can be provided, so predictions can be made using training or test data files (but the label will not affect the predictions).
13
+
14
+ .. code-block:: bash
15
+
16
+ dscript predict --pairs [input data] --seqs [sequences, .fasta format] --model [model file]
17
+
18
+ Embed sequences with language model
19
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20
+
21
+ Sequences should be in ``.fasta`` format.
22
+
23
+ .. code-block:: bash
24
+
25
+ dscript embed --seqs [sequences] --outfile [embedding file]
26
+
27
+ Train and save a model
28
+ ^^^^^^^^^^^^^^^^^^^^^^
29
+
30
+ Training and validation data should be in tab-separated (``.tsv``) format with no header, and columns for [protein name 1], [protein name 2], [label].
31
+
32
+ .. code-block:: bash
33
+
34
+ dscript train --train [training data] --val [validation data] --embedding [embedding file] --save-prefix [prefix]
35
+
36
+
37
+ Evaluate a trained model
38
+ ^^^^^^^^^^^^^^^^^^^^^^^^
39
+
40
+ .. code-block:: bash
41
+
42
+ dscript eval --model [model file] --test [test data] --embedding [embedding file] --outfile [result file]
43
+
44
+
45
+ Prediction
46
+ ~~~~~~~~~~
47
+
48
+ .. code-block:: bash
49
+
50
+ usage: dscript predict [-h] --pairs PAIRS --model MODEL [--seqs SEQS]
51
+ [--embeddings EMBEDDINGS] [-o OUTFILE] [-d DEVICE]
52
+ [--thresh THRESH]
53
+
54
+ Make new predictions with a pre-trained model. One of --seqs and --embeddings is required.
55
+
56
+ optional arguments:
57
+ -h, --help show this help message and exit
58
+ --pairs PAIRS Candidate protein pairs to predict
59
+ --model MODEL Pretrained Model
60
+ --seqs SEQS Protein sequences in .fasta format
61
+ --embeddings EMBEDDINGS
62
+ h5 file with embedded sequences
63
+ -o OUTFILE, --outfile OUTFILE
64
+ File for predictions
65
+ -d DEVICE, --device DEVICE
66
+ Compute device to use
67
+ --thresh THRESH Positive prediction threshold - used to store contact
68
+ maps and predictions in a separate file. [default:
69
+ 0.5]
70
+
71
+ Embedding
72
+ ~~~~~~~~~
73
+
74
+ .. code-block:: bash
75
+
76
+ usage: dscript embed [-h] --seqs SEQS --outfile OUTFILE [-d DEVICE]
77
+
78
+ Generate new embeddings using pre-trained language model
79
+
80
+ optional arguments:
81
+ -h, --help show this help message and exit
82
+ --seqs SEQS Sequences to be embedded
83
+ --outfile OUTFILE h5 file to write results
84
+ -d DEVICE, --device DEVICE
85
+ Compute device to use
86
+
87
+ Training
88
+ ~~~~~~~~
89
+
90
+ .. code-block:: bash
91
+
92
+ usage: dscript train [-h] --train TRAIN --val VAL --embedding EMBEDDING
93
+ [--augment] [--projection-dim PROJECTION_DIM]
94
+ [--dropout-p DROPOUT_P] [--hidden-dim HIDDEN_DIM]
95
+ [--kernel-width KERNEL_WIDTH] [--use-w]
96
+ [--pool-width POOL_WIDTH]
97
+ [--negative-ratio NEGATIVE_RATIO]
98
+ [--epoch-scale EPOCH_SCALE] [--num-epochs NUM_EPOCHS]
99
+ [--batch-size BATCH_SIZE] [--weight-decay WEIGHT_DECAY]
100
+ [--lr LR] [--lambda LAMBDA_] [-o OUTFILE]
101
+ [--save-prefix SAVE_PREFIX] [-d DEVICE]
102
+ [--checkpoint CHECKPOINT]
103
+
104
+ Train a new model
105
+
106
+ optional arguments:
107
+ -h, --help show this help message and exit
108
+
109
+ Data:
110
+ --train TRAIN Training data
111
+ --val VAL Validation data
112
+ --embedding EMBEDDING
113
+ h5 file with embedded sequences
114
+ --augment Set flag to augment data by adding (B A) for all pairs
115
+ (A B)
116
+
117
+ Projection Module:
118
+ --projection-dim PROJECTION_DIM
119
+ Dimension of embedding projection layer (default: 100)
120
+ --dropout-p DROPOUT_P
121
+ Parameter p for embedding dropout layer (default: 0.5)
122
+
123
+ Contact Module:
124
+ --hidden-dim HIDDEN_DIM
125
+ Number of hidden units for comparison layer in contact
126
+ prediction (default: 50)
127
+ --kernel-width KERNEL_WIDTH
128
+ Width of convolutional filter for contact prediction
129
+ (default: 7)
130
+
131
+ Interaction Module:
132
+ --use-w Use weight matrix in interaction prediction model
133
+ --pool-width POOL_WIDTH
134
+ Size of max-pool in interaction model (default: 9)
135
+
136
+ Training:
137
+ --negative-ratio NEGATIVE_RATIO
138
+ Number of negative training samples for each positive
139
+ training sample (default: 10)
140
+ --epoch-scale EPOCH_SCALE
141
+ Report heldout performance every this many epochs
142
+ (default: 5)
143
+ --num-epochs NUM_EPOCHS
144
+ Number of epochs (default: 100)
145
+ --batch-size BATCH_SIZE
146
+ Minibatch size (default: 25)
147
+ --weight-decay WEIGHT_DECAY
148
+ L2 regularization (default: 0)
149
+ --lr LR Learning rate (default: 0.001)
150
+ --lambda LAMBDA_ Weight on the similarity objective (default: 0.35)
151
+
152
+ Output and Device:
153
+ -o OUTPUT, --output OUTPUT
154
+ Output file path (default: stdout)
155
+ --save-prefix SAVE_PREFIX
156
+ Path prefix for saving models
157
+ -d DEVICE, --device DEVICE
158
+ Compute device to use
159
+ --checkpoint CHECKPOINT
160
+ Checkpoint model to start training from``
161
+
162
+ Evaluation
163
+ ~~~~~~~~~~
164
+
165
+ .. code-block:: bash
166
+
167
+ usage: dscript eval [-h] --model MODEL --test TEST --embedding EMBEDDING
168
+ [-o OUTFILE] [-d DEVICE]
169
+
170
+ Evaluate a trained model
171
+
172
+ optional arguments:
173
+ -h, --help show this help message and exit
174
+ --model MODEL Trained prediction model
175
+ --test TEST Test Data
176
+ --embedding EMBEDDING
177
+ h5 file with embedded sequences
178
+ -o OUTFILE, --outfile OUTFILE
179
+ Output file to write results
180
+ -d DEVICE, --device DEVICE
181
+ Compute device to use