Upload 686 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .github/CODEOWNERS +2 -0
- .gitignore +16 -0
- .ipynb_checkpoints/untitled-checkpoint.py +0 -0
- .rosetta-ci/.gitignore +3 -0
- .rosetta-ci/benchmark.py +410 -0
- .rosetta-ci/benchmark.template.ini +40 -0
- .rosetta-ci/hpc_drivers/__init__.py +5 -0
- .rosetta-ci/hpc_drivers/base.py +210 -0
- .rosetta-ci/hpc_drivers/multicore.py +184 -0
- .rosetta-ci/hpc_drivers/slurm.py +176 -0
- .rosetta-ci/test-sets.yaml +65 -0
- .rosetta-ci/tests/__init__.py +765 -0
- .rosetta-ci/tests/rfd.py +111 -0
- .rosetta-ci/tests/self.md +6 -0
- .rosetta-ci/tests/self.py +209 -0
- END +7 -0
- LICENSE +30 -0
- README.md +514 -1
- appverifUI.dll +0 -0
- config/inference/base.yaml +136 -0
- config/inference/symmetry.yaml +26 -0
- docker/Dockerfile +50 -0
- env/SE3Transformer/.dockerignore +123 -0
- env/SE3Transformer/.gitignore +121 -0
- env/SE3Transformer/Dockerfile +58 -0
- env/SE3Transformer/LICENSE +7 -0
- env/SE3Transformer/NOTICE +7 -0
- env/SE3Transformer/README.md +580 -0
- env/SE3Transformer/build/lib/se3_transformer/__init__.py +0 -0
- env/SE3Transformer/build/lib/se3_transformer/data_loading/__init__.py +1 -0
- env/SE3Transformer/build/lib/se3_transformer/data_loading/data_module.py +63 -0
- env/SE3Transformer/build/lib/se3_transformer/data_loading/qm9.py +173 -0
- env/SE3Transformer/build/lib/se3_transformer/model/__init__.py +2 -0
- env/SE3Transformer/build/lib/se3_transformer/model/basis.py +178 -0
- env/SE3Transformer/build/lib/se3_transformer/model/fiber.py +144 -0
- env/SE3Transformer/build/lib/se3_transformer/model/layers/__init__.py +5 -0
- env/SE3Transformer/build/lib/se3_transformer/model/layers/attention.py +180 -0
- env/SE3Transformer/build/lib/se3_transformer/model/layers/convolution.py +336 -0
- env/SE3Transformer/build/lib/se3_transformer/model/layers/linear.py +59 -0
- env/SE3Transformer/build/lib/se3_transformer/model/layers/norm.py +83 -0
- env/SE3Transformer/build/lib/se3_transformer/model/layers/pooling.py +53 -0
- env/SE3Transformer/build/lib/se3_transformer/model/transformer.py +222 -0
- env/SE3Transformer/build/lib/se3_transformer/runtime/__init__.py +0 -0
- env/SE3Transformer/build/lib/se3_transformer/runtime/arguments.py +70 -0
- env/SE3Transformer/build/lib/se3_transformer/runtime/callbacks.py +160 -0
- env/SE3Transformer/build/lib/se3_transformer/runtime/gpu_affinity.py +325 -0
- env/SE3Transformer/build/lib/se3_transformer/runtime/inference.py +131 -0
- env/SE3Transformer/build/lib/se3_transformer/runtime/loggers.py +134 -0
- env/SE3Transformer/build/lib/se3_transformer/runtime/metrics.py +83 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
env/SE3Transformer/images/se3-transformer.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
img/diffusion_protein_gradient_2.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
pyrosetta-2023.14+release.7132bdc754a-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
.github/CODEOWNERS
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Benchmark scripts
|
2 |
+
/.rosetta-ci @lyskov
|
.gitignore
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.py[cod]
|
2 |
+
rfdiffusion.egg-info
|
3 |
+
|
4 |
+
models/
|
5 |
+
schedules/
|
6 |
+
|
7 |
+
examples/ppi_scaffolds
|
8 |
+
|
9 |
+
tests/.results.json
|
10 |
+
tests/input_pdbs
|
11 |
+
tests/outputs
|
12 |
+
tests/ppi_scaffolds
|
13 |
+
tests/reference_outputs/
|
14 |
+
tests/target_folds
|
15 |
+
tests/tim_barrel_scaffold
|
16 |
+
tests/tests_*
|
.ipynb_checkpoints/untitled-checkpoint.py
ADDED
File without changes
|
.rosetta-ci/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
*.pyc
|
2 |
+
results/
|
3 |
+
benchmark.ubuntu.ini
|
.rosetta-ci/benchmark.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# :noTabs=true:
|
4 |
+
|
5 |
+
# (c) Copyright Rosetta Commons Member Institutions.
|
6 |
+
# (c) This file is part of the Rosetta software suite and is made available under license.
|
7 |
+
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
|
8 |
+
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
|
9 |
+
# (c) addressed to University of Washington CoMotion, email: license@uw.edu.
|
10 |
+
|
11 |
+
## @file benchmark.py
|
12 |
+
## @brief Run arbitrary Rosetta testing script
|
13 |
+
## @author Sergey Lyskov
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
|
17 |
+
import os, os.path, sys, shutil, json, platform, re
|
18 |
+
import codecs
|
19 |
+
|
20 |
+
from importlib.machinery import SourceFileLoader
|
21 |
+
|
22 |
+
from configparser import ConfigParser, ExtendedInterpolation
|
23 |
+
import argparse
|
24 |
+
|
25 |
+
from tests import * # execute, Tests states and key names
|
26 |
+
from hpc_drivers import *
|
27 |
+
|
28 |
+
|
29 |
+
# Calculating value of Platform dict
|
30 |
+
Platform = {}
|
31 |
+
if sys.platform.startswith("linux"):
|
32 |
+
Platform['os'] = 'ubuntu' if os.path.isfile('/etc/lsb-release') and 'Ubuntu' in open('/etc/lsb-release').read() else 'linux' # can be linux1, linux2, etc
|
33 |
+
elif sys.platform == "darwin" : Platform['os'] = 'mac'
|
34 |
+
elif sys.platform == "cygwin" : Platform['os'] = 'cygwin'
|
35 |
+
elif sys.platform == "win32" : Platform['os'] = 'windows'
|
36 |
+
else: Platform['os'] = 'unknown'
|
37 |
+
|
38 |
+
#Platform['arch'] = platform.architecture()[0][:2] # PlatformBits
|
39 |
+
Platform['compiler'] = 'gcc' if Platform['os'] == 'linux' else 'clang'
|
40 |
+
|
41 |
+
Platform['python'] = sys.executable
|
42 |
+
|
43 |
+
|
44 |
+
def load_python_source_from_file(module_name, module_path):
|
45 |
+
''' replacment for deprecated imp.load_source
|
46 |
+
'''
|
47 |
+
return SourceFileLoader(module_name, module_path).load_module()
|
48 |
+
|
49 |
+
|
50 |
+
class Setup(object):
|
51 |
+
__slots__ = 'test working_dir platform config compare debug'.split() # version daemon path_to_previous_test
|
52 |
+
def __init__(self, **attrs):
|
53 |
+
#self.daemon = True
|
54 |
+
for k, v in attrs.items():
|
55 |
+
if k in self.__slots__: setattr(self, k, v)
|
56 |
+
|
57 |
+
|
58 |
+
def setup_from_options(options):
|
59 |
+
''' Create Setup object based on user supplied options, config files and auto-detection
|
60 |
+
'''
|
61 |
+
platform = dict(Platform)
|
62 |
+
|
63 |
+
if options.suffix: options.suffix = '.' + options.suffix
|
64 |
+
|
65 |
+
platform['extras'] = options.extras.split(',') if options.extras else []
|
66 |
+
platform['python'] = options.python
|
67 |
+
#platform['options'] = json.loads( options.options ) if options.options else {}
|
68 |
+
|
69 |
+
if options.memory: memory = options.memory
|
70 |
+
elif platform['os'] in ['linux', 'ubuntu']: memory = int( execute('Getting memory info...', 'free -m', terminate_on_failure=False, silent=True, silence_output_on_errors=True, return_='output').split('\n')[1].split()[1]) // 1024
|
71 |
+
elif platform['os'] == 'mac': memory = int( execute('Getting memory info...', 'sysctl -a | grep hw.memsize', terminate_on_failure=False, silent=True, silence_output_on_errors=True, return_='output').split()[1]) // 1024 // 1024 // 1024
|
72 |
+
|
73 |
+
platform['compiler'] = options.compiler
|
74 |
+
|
75 |
+
if os.path.isfile(options.config):
|
76 |
+
with open(options.config) as f:
|
77 |
+
if '%(here)s' in f.read():
|
78 |
+
print(f"\n\n>>> ERROR file `{options.config}` seems to be in outdated format! Please use benchmark.template.ini to update it.")
|
79 |
+
sys.exit(1)
|
80 |
+
|
81 |
+
user_config = ConfigParser(
|
82 |
+
dict(
|
83 |
+
_here_ = os.path.abspath('./'),
|
84 |
+
_user_home_ = os.environ['HOME']
|
85 |
+
),
|
86 |
+
interpolation = ExtendedInterpolation()
|
87 |
+
)
|
88 |
+
|
89 |
+
with open(options.config) as f: user_config.readfp(f)
|
90 |
+
|
91 |
+
else:
|
92 |
+
print(f"\n\n>>> Config file `{options.config}` not found. You may want to manually copy `benchmark.ini.template` to `{options.config}` and edit the settings\n\n")
|
93 |
+
user_config = ConfigParser()
|
94 |
+
user_config.set('main', 'cpu_count', '1')
|
95 |
+
user_config.set('main', 'hpc_driver', 'MultiCore')
|
96 |
+
user_config.set('main', 'branch', 'unknown')
|
97 |
+
user_config.set('main', 'revision', '42')
|
98 |
+
user_config.set('main', 'user_name', 'Jane Roe')
|
99 |
+
user_config.set('main', 'user_email', 'jane.roe@university.edu')
|
100 |
+
user_config.add_section('main')
|
101 |
+
|
102 |
+
if options.jobs: user_config.set('main', 'cpu_count', str(options.jobs) )
|
103 |
+
user_config.set('main', 'memory', str(memory) )
|
104 |
+
|
105 |
+
if options.mount:
|
106 |
+
for m in options.mount:
|
107 |
+
key, _, path = m.partition(':')
|
108 |
+
user_config.set('mount', key, path)
|
109 |
+
|
110 |
+
#config = Config.items('config')
|
111 |
+
#for section in config.sections(): print('Config section: ', section, dict(config.items(section)))
|
112 |
+
#config = { section: dict(Config.items(section)) for section in Config.sections() }
|
113 |
+
|
114 |
+
config = { k : d for k, d in user_config['main'].items() if k not in user_config[user_config.default_section] }
|
115 |
+
config['mounts'] = { k : d for k, d in user_config['mount'].items() if k not in user_config[user_config.default_section] }
|
116 |
+
|
117 |
+
#print(json.dumps(config, sort_keys=True, indent=2)); sys.exit(1)
|
118 |
+
|
119 |
+
#config.update( config.pop('config').items() )
|
120 |
+
|
121 |
+
config = dict(config,
|
122 |
+
cpu_count = user_config.getint('main', 'cpu_count'),
|
123 |
+
memory = memory,
|
124 |
+
revision = user_config.getint('main', 'revision'),
|
125 |
+
emulation=True,
|
126 |
+
) # debug=options.debug,
|
127 |
+
|
128 |
+
if 'results_root' not in config: config['results_root'] = os.path.abspath('./results/')
|
129 |
+
|
130 |
+
if 'prefix' in config:
|
131 |
+
assert os.path.isabs( config['prefix'] ), f'ERROR: `prefix` path must be absolute! Got: {config["prefix"]}'
|
132 |
+
|
133 |
+
else: config['prefix'] = os.path.abspath( config['results_root'] + '/prefix')
|
134 |
+
|
135 |
+
config['merge_head'] = options.merge_head
|
136 |
+
config['merge_base'] = options.merge_base
|
137 |
+
|
138 |
+
if options.skip_compile is not None: config['skip_compile'] = options.skip_compile
|
139 |
+
|
140 |
+
#print(f'Results path: {config["results_root"]}')
|
141 |
+
#print('Config:{}, Platform:{}'.format(json.dumps(config, sort_keys=True, indent=2), Platform))
|
142 |
+
|
143 |
+
if options.compare: print('Comparing tests {} with suffixes: {}'.format(options.args, options.compare) )
|
144 |
+
else: print('Running tests: {}'.format(options.args) )
|
145 |
+
|
146 |
+
if len(options.args) != 1: print('Error: Single test-name-to-run should be supplied!'); sys.exit(1)
|
147 |
+
else:
|
148 |
+
test = options.args[0]
|
149 |
+
if test.startswith('tests/'): test = test.partition('tests/')[2][:-3] # removing dir prefix and .py suffix
|
150 |
+
|
151 |
+
if options.compare:
|
152 |
+
compare = options.compare[0], options.compare[1] # (this test suffix, previous test suffix)
|
153 |
+
working_dir = os.path.abspath( config['results_root'] + f'/{platform["os"]}.{test}' ) # will be a root dir with sub-dirs (options.compare[0], options.compare[1])
|
154 |
+
else:
|
155 |
+
compare = None
|
156 |
+
working_dir = os.path.abspath( config['results_root'] + f'/{platform["os"]}.{test}{options.suffix}' )
|
157 |
+
|
158 |
+
|
159 |
+
if os.path.isdir(working_dir): shutil.rmtree(working_dir); #print('Removing old job dir %s...' % working_dir) # remove old dir if any
|
160 |
+
os.makedirs(working_dir)
|
161 |
+
|
162 |
+
setup = Setup(
|
163 |
+
test = test,
|
164 |
+
working_dir = working_dir,
|
165 |
+
platform = platform,
|
166 |
+
config = config,
|
167 |
+
compare = compare,
|
168 |
+
debug = options.debug,
|
169 |
+
#daemon = False,
|
170 |
+
)
|
171 |
+
|
172 |
+
setup_as_json = json.dumps( { k : getattr(setup, k) for k in setup.__slots__}, sort_keys=True, indent=2)
|
173 |
+
with open(working_dir + '/.setup.json', 'w') as f: f.write(setup_as_json)
|
174 |
+
|
175 |
+
#print(f'Detected hardware platform: {Platform}')
|
176 |
+
print(f'Setup: {setup_as_json}')
|
177 |
+
return setup
|
178 |
+
|
179 |
+
|
180 |
+
def truncate_log(log):
|
181 |
+
_max_log_size_ = 1024*1024*1
|
182 |
+
_max_line_size_ = _max_log_size_ // 2
|
183 |
+
|
184 |
+
if len(log) > _max_log_size_:
|
185 |
+
new = log
|
186 |
+
lines = log.split('\n')
|
187 |
+
|
188 |
+
if len(lines) > 256:
|
189 |
+
new_lines = lines[:32] + ['...truncated...'] + lines[-128:]
|
190 |
+
new = '\n'.join(new_lines)
|
191 |
+
|
192 |
+
if len(new) > _max_log_size_: # special case for Ninja logs that does not use \n
|
193 |
+
lines = re.split(r'[\r\n]*', log) #t.log.split('\r')
|
194 |
+
if len(lines) > 256: new = '\n'.join( lines[:32] + ['...truncated...'] + lines[-128:] )
|
195 |
+
|
196 |
+
if len(new) > _max_log_size_: # going to try to truncate each individual line...
|
197 |
+
print(f'Trying to truncate log line-by-line...')
|
198 |
+
new = '\n'.join( (
|
199 |
+
( line[:_max_line_size_//3] + '...truncated...' + line[-_max_line_size_//3:] ) if line > _max_line_size_ else line
|
200 |
+
for line in new_lines ) )
|
201 |
+
|
202 |
+
if len(new) > _max_log_size_: # fall-back strategy in case all of the above failed...
|
203 |
+
print(f'WARNING: could not truncate log line-by-line, falling back to raw truncate...')
|
204 |
+
new = 'WARNING: could not truncate test log line-by-line, falling back to raw truncate!\n...truncated...\n' + ( '\n'.join(lines) )[-_max_log_size_+256:]
|
205 |
+
|
206 |
+
print( 'Trunacting test output log: {0}MiB --> {1}MiB'.format(len(log)/1024/1024, len(new)/1024/1024) )
|
207 |
+
|
208 |
+
log = new
|
209 |
+
|
210 |
+
return log
|
211 |
+
|
212 |
+
def truncate_results_logs(results):
|
213 |
+
results[_LogKey_] = truncate_log( results[_LogKey_] )
|
214 |
+
if _ResultsKey_ in results and _TestsKey_ in results[_ResultsKey_]:
|
215 |
+
tests = results[_ResultsKey_][_TestsKey_]
|
216 |
+
for test in tests:
|
217 |
+
tests[test][_LogKey_] = truncate_log( tests[test][_LogKey_] )
|
218 |
+
|
219 |
+
|
220 |
+
def find_test_description(test_name, test_script_file_name):
|
221 |
+
''' return content of test-description file if any or None if no description was found
|
222 |
+
'''
|
223 |
+
|
224 |
+
def find_description_file(prefix, test_name):
|
225 |
+
fname = prefix + test_name + '.md'
|
226 |
+
if os.path.isfile(fname): return fname
|
227 |
+
return prefix + 'md'
|
228 |
+
|
229 |
+
description_file_name = find_description_file( test_script_file_name[:-len('command.py')] + 'description.', test_name) if test_script_file_name.endswith('/command.py') else find_description_file(test_script_file_name[:-len('py')], test_name)
|
230 |
+
|
231 |
+
if description_file_name and os.path.isfile(description_file_name):
|
232 |
+
print(f'Found test suite description in file: {description_file_name!r}')
|
233 |
+
with open(description_file_name, encoding='utf-8', errors='backslashreplace') as f: description = f.read()
|
234 |
+
return description
|
235 |
+
|
236 |
+
else: return None
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
def run_test(setup):
|
241 |
+
#print(f'{setup!r}')
|
242 |
+
suite, rest = setup.test.split('.'), []
|
243 |
+
while suite:
|
244 |
+
#print( f'suite: {suite}, test: {rest}' )
|
245 |
+
|
246 |
+
file_name = '/'.join( ['tests'] + suite ) + '.py'
|
247 |
+
if os.path.isfile(file_name): break
|
248 |
+
|
249 |
+
file_name = '/'.join( ['tests'] + suite ) + '/command.py'
|
250 |
+
if os.path.isfile(file_name): break
|
251 |
+
|
252 |
+
rest.insert(0, suite.pop())
|
253 |
+
|
254 |
+
|
255 |
+
test = '.'.join( suite + rest )
|
256 |
+
test_name = '.'.join(rest)
|
257 |
+
|
258 |
+
print( f'Loading test from: {file_name}, suite+test: {test!r}, test: {test_name!r}' )
|
259 |
+
#test_suite = imp.load_source('test_suite', file_name)
|
260 |
+
test_suite = load_python_source_from_file('test_suite', file_name)
|
261 |
+
|
262 |
+
test_description = find_test_description(test_name, file_name)
|
263 |
+
|
264 |
+
if setup.compare:
|
265 |
+
#working_dir_1 = os.path.abspath( config['results_root'] + f'/{Platform["os"]}.{test}.{Options.compare[0]}' )
|
266 |
+
working_dir_1 = setup.working_dir + f'/{setup.compare[0]}'
|
267 |
+
|
268 |
+
working_dir_2 = setup.compare[1] and ( setup.working_dir + f'/{setup.compare[1]}' )
|
269 |
+
res_2_json_file_path = setup.compare[1] and f'{working_dir_2}/.execution.results.json'
|
270 |
+
|
271 |
+
with open(working_dir_1 + '/.execution.results.json') as f: res_1 = json.load(f).get(_ResultsKey_)
|
272 |
+
|
273 |
+
if setup.compare[1] and ( not os.path.isfile(res_2_json_file_path) ):
|
274 |
+
setup.compare[1] = None
|
275 |
+
state_override = _S_failed_
|
276 |
+
else:
|
277 |
+
state_override = None
|
278 |
+
|
279 |
+
if setup.compare[1] == None: res_2, working_dir_2 = None, None
|
280 |
+
else:
|
281 |
+
with open(res_2_json_file_path) as f: res_2 = json.load(f).get(_ResultsKey_)
|
282 |
+
|
283 |
+
res = test_suite.compare(test, res_1, working_dir_1, res_2, working_dir_2)
|
284 |
+
|
285 |
+
if state_override:
|
286 |
+
log_prefix = \
|
287 |
+
f'WARNING: Previous test results does not have `.execution.results.json` file, so comparision with None was performed instead!\n' \
|
288 |
+
f'WARNING: Overriding calcualted test state `{res[_StateKey_]}` → `{_S_failed_}`...\n\n'
|
289 |
+
|
290 |
+
res[_LogKey_] = log_prefix + res[_LogKey_]
|
291 |
+
res[_StateKey_] = _S_failed_
|
292 |
+
|
293 |
+
|
294 |
+
# # Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages.
|
295 |
+
# with codecs.open(setup.working_dir+'/.comparison.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write( truncate_log( res[_LogKey_] ) )
|
296 |
+
# res[_LogKey_] = truncate_log( res[_LogKey_] )
|
297 |
+
|
298 |
+
# # Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages.
|
299 |
+
with codecs.open(setup.working_dir+'/.comparison.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write(res[_LogKey_])
|
300 |
+
truncate_results_logs(res)
|
301 |
+
|
302 |
+
print( 'Comparison finished with output:\n{}'.format( res[_LogKey_] ) )
|
303 |
+
|
304 |
+
with open(setup.working_dir+'/.comparison.results.json', 'w') as f: json.dump(res, f, sort_keys=True, indent=2)
|
305 |
+
|
306 |
+
#print( 'Comparison finished with results:\n{}'.format( json.dumps(res, sort_keys=True, indent=2) ) )
|
307 |
+
if 'summary' in res: print('Summary section:\n{}'.format( json.dumps(res['summary'], sort_keys=True, indent=2) ) )
|
308 |
+
|
309 |
+
print( f'Output results of this comparison saved to {working_dir_1}/.comparison.results.json\nComparison log saved into {working_dir_1}/.comparison.log.txt' )
|
310 |
+
|
311 |
+
|
312 |
+
else:
|
313 |
+
working_dir = setup.working_dir #os.path.abspath( setup.config['results_root'] + f'/{platform["os"]}.{test}{options.suffix}' )
|
314 |
+
|
315 |
+
hpc_driver_name = setup.config['hpc_driver']
|
316 |
+
hpc_driver = None if hpc_driver_name in ['', 'none'] else eval(hpc_driver_name + '_HPC_Driver')(working_dir, setup.config, tracer=print, set_daemon_message=lambda x:None)
|
317 |
+
|
318 |
+
api_version = test_suite._api_version_ if hasattr(test_suite, '_api_version_') else ''
|
319 |
+
|
320 |
+
# if api_version < '1.0':
|
321 |
+
# res = test_suite.run(test=test_name, rosetta_dir=os.path.abspath('../..'), working_dir=working_dir, platform=dict(Platform), jobs=Config.cpu_count, verbose=True, debug=Options.debug)
|
322 |
+
# else:
|
323 |
+
|
324 |
+
if api_version == '1.0': res = test_suite.run(test=test_name, repository_root=os.path.abspath('./..'), working_dir=working_dir, platform=dict(setup.platform), config=setup.config, hpc_driver=hpc_driver, verbose=True, debug=setup.debug)
|
325 |
+
else:
|
326 |
+
print(f'Test benchmark api_version={api_version} is not supported!'); sys.exit(1)
|
327 |
+
|
328 |
+
if not isinstance(res, dict): print(f'Test returned result of type {type(res)} while dict-like object was expected, please check that test-script have correct `return` statment! Terminating...'); sys.exit(1)
|
329 |
+
|
330 |
+
# Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages
|
331 |
+
with codecs.open(working_dir+'/.execution.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write( res[_LogKey_] )
|
332 |
+
|
333 |
+
# res[_LogKey_] = truncate_log( res[_LogKey_] )
|
334 |
+
truncate_results_logs(res)
|
335 |
+
|
336 |
+
if _DescriptionKey_ not in res: res[_DescriptionKey_] = test_description
|
337 |
+
|
338 |
+
if res[_StateKey_] not in _S_Values_: print( 'Warning!!! Test {} failed with unknow result code: {}'.format(test_name, res[_StateKey_]) )
|
339 |
+
else: print( f'Test {test} finished with output:\n{res[_LogKey_]}\n----------------------------------------------------------------\nState: {res[_StateKey_]!r} | ', end='')
|
340 |
+
|
341 |
+
# JSON by default serializes to an ascii-encoded format
|
342 |
+
with open(working_dir+'/.execution.results.json', 'w') as f: json.dump(res, f, sort_keys=True, indent=2)
|
343 |
+
|
344 |
+
print( f'Output and full log of this test saved to:\n{working_dir}/.execution.results.json\n{working_dir}/.execution.log.txt' )
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
|
349 |
+
|
350 |
+
|
351 |
+
def main(args):
|
352 |
+
''' Script to Run arbitrary Rosetta test
|
353 |
+
'''
|
354 |
+
parser = argparse.ArgumentParser(usage="Main testing script to run tests in the tests directory. "
|
355 |
+
"Use the --skip-compile to skip the build phase when testing locally. "
|
356 |
+
"Example Command: /benchmark.py -j2 integration.valgrind")
|
357 |
+
|
358 |
+
parser.add_argument('-j', '--jobs', default=0, type=int, help="Number of processors to use on when building. (default: use value from config file or 1)")
|
359 |
+
|
360 |
+
parser.add_argument('-m', '--memory', default=0, type=int, help="Amount of memory to use (default: use 2Gb per job")
|
361 |
+
|
362 |
+
parser.add_argument('--compiler', default=Platform['compiler'], help="Compiler to use")
|
363 |
+
|
364 |
+
#parser.add_argument('--python', default=('3.9' if Platform['os'] == 'mac' else '3.6'), help="Python interpreter to use")
|
365 |
+
parser.add_argument('--python', default=f'{sys.version_info.major}.{sys.version_info.minor}.s', help="Specify version of Python interpreter to use, for example '3.9'. If '.s' added to end of version string then use the same interpreter that was used to start this script. Default: '?.?.s'")
|
366 |
+
|
367 |
+
parser.add_argument("--extras", default='', help="Specify scons extras separated by ',': like --extras=mpi,static" )
|
368 |
+
|
369 |
+
parser.add_argument("--debug", action="store_true", dest="debug", default=False, help="Run specified test in debug mode (not with debug build!) this mean different things and depend on the test. Could be: skip the build phase, skip some of the test phases and so on. [off by default]" )
|
370 |
+
|
371 |
+
parser.add_argument("--suffix", default='', help="Specify ending suffix for test output dir. This is useful when you want to save test results in different dir for later comparison." )
|
372 |
+
|
373 |
+
parser.add_argument("--compare", nargs=2, help="Do not run the tests but instead compare previous results. Use --compare suffix1 suffix2" )
|
374 |
+
|
375 |
+
parser.add_argument("--config", default='benchmark.{os}.ini'.format(os=Platform['os']), action="store", help="Location of .ini file with additional options configuration. Optional.")
|
376 |
+
|
377 |
+
parser.add_argument("--skip-compile", dest='skip_compile', default=None, action="store_true", help="Skip the compilation phase. Assumes the binaries are already compiled locally.")
|
378 |
+
|
379 |
+
#parser.add_argument("--results-root", default=None, action="store", help="Location of `results` dir default is to use `./results`")
|
380 |
+
|
381 |
+
parser.add_argument("--setup", default=None, help="Specify JSON file with setup information. When this option supplied all other config and commandline options is ignored and auto-detection disable. Test, platform info will be gathered from provided JSON file. This option is designed to be used in daemon mode." )
|
382 |
+
|
383 |
+
parser.add_argument("--merge-head", default='HEAD', help="Specify SHA1/branch-name that will be used for `merge-head` value when simulating PR testing" )
|
384 |
+
|
385 |
+
parser.add_argument("--merge-base", default='origin/master', help="Specify SHA1/branch-name that will be used for `merge-base` value when simulating PR testing" )
|
386 |
+
|
387 |
+
parser.add_argument("--mount", action="append", help="Specify one of the mount points, like: --mount release_root:/some/path. This option could be used multiple times if needed" )
|
388 |
+
|
389 |
+
|
390 |
+
parser.add_argument('args', nargs=argparse.REMAINDER)
|
391 |
+
|
392 |
+
options = parser.parse_args(args=args[1:])
|
393 |
+
|
394 |
+
if any( [a.startswith('-') for a in options.args] ) :
|
395 |
+
print( '\nWARNING WARNING WARNING WARNING\n' )
|
396 |
+
print( '\tInterpreting', ' '.join(["'"+a+"'" for a in options.args if a.startswith('-')]), 'as test name(s), rather than as option(s).' )
|
397 |
+
print( "\tTry moving it before any test name, if that's not what you want." )
|
398 |
+
print( '\nWARNING WARNING WARNING WARNING\n' )
|
399 |
+
|
400 |
+
|
401 |
+
if options.setup:
|
402 |
+
with open(options.setup) as f: setup = Setup( **json.load(f) )
|
403 |
+
|
404 |
+
else:
|
405 |
+
setup = setup_from_options(options)
|
406 |
+
|
407 |
+
run_test(setup)
|
408 |
+
|
409 |
+
|
410 |
+
if __name__ == "__main__": main(sys.argv)
|
.rosetta-ci/benchmark.template.ini
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Benchmark script configuration file. Some of the tests require some system specific options to run. Please see benchmark.ini.template for list of available options.
|
3 |
+
#
|
4 |
+
|
5 |
+
[DEFAULT]
|
6 |
+
|
7 |
+
[main] # additional config-options for various tests. All this fields will be pass as keys in 'config' function argument
|
8 |
+
|
9 |
+
# how many jobs daemon can run on host machine (this is not related to HPC jobs)
|
10 |
+
cpu_count = 24
|
11 |
+
|
12 |
+
# how many memory in GB daemon can use on host machine (approximation, float)
|
13 |
+
memory = 64
|
14 |
+
|
15 |
+
# user name and email for user who submitted this test
|
16 |
+
user_name = Jane Roe
|
17 |
+
user_email = jane.roe@university.edu
|
18 |
+
|
19 |
+
# HPC Driver, might have one of the following values: MultiCore, Condor, Slurm or none if no HPC Driver should be configured
|
20 |
+
hpc_driver = MultiCore
|
21 |
+
|
22 |
+
# when running by daemons branch:revision will be set to appropriate values to represent currently checked version of main repository
|
23 |
+
branch = unknown
|
24 |
+
revision = 42
|
25 |
+
|
26 |
+
# path to directory where test results will be stored
|
27 |
+
results_root = ${_here_}/results
|
28 |
+
|
29 |
+
release_root = ./results/_release_
|
30 |
+
|
31 |
+
[slurm]
|
32 |
+
# head-node host name, if specified will be used to submit jobs
|
33 |
+
head_node =
|
34 |
+
|
35 |
+
|
36 |
+
[mount]
|
37 |
+
# list of key:path pairs that will be avalible as config.mounts during test run
|
38 |
+
|
39 |
+
# path to releases, leave empty if release production should not be supported by this daemon
|
40 |
+
release_root = ${_here_}/release
|
.rosetta-ci/hpc_drivers/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# :noTabs=true:
|
3 |
+
|
4 |
+
from .multicore import MultiCore_HPC_Driver
|
5 |
+
from .slurm import Slurm_HPC_Driver
|
.rosetta-ci/hpc_drivers/base.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# :noTabs=true:
|
3 |
+
|
4 |
+
import os, sys, subprocess, stat
|
5 |
+
import time as time_module
|
6 |
+
import signal as signal_module
|
7 |
+
|
8 |
+
class NT: # named tuple
|
9 |
+
def __init__(self, **entries): self.__dict__.update(entries)
|
10 |
+
def __repr__(self):
|
11 |
+
r = 'NT: |'
|
12 |
+
for i in dir(self):
|
13 |
+
if not i.startswith('__') and not isinstance(getattr(self, i), types.MethodType): r += '{} --> {}, '.format(i, getattr(self, i))
|
14 |
+
return r[:-2]+'|'
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
class HPC_Exception(Exception):
|
19 |
+
def __init__(self, value): self.value = value
|
20 |
+
def __str__(self): return self.value
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
def execute(message, command_line, return_='status', until_successes=False, terminate_on_failure=True, silent=False, silence_output=False, tracer=print):
|
25 |
+
if not silent: tracer(message); tracer(command_line); sys.stdout.flush();
|
26 |
+
while True:
|
27 |
+
|
28 |
+
p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
29 |
+
output, errors = p.communicate()
|
30 |
+
|
31 |
+
output = output + errors
|
32 |
+
|
33 |
+
output = output.decode(encoding="utf-8", errors="replace")
|
34 |
+
|
35 |
+
exit_code = p.returncode
|
36 |
+
|
37 |
+
if exit_code and not (silent or silence_output): tracer(output); sys.stdout.flush();
|
38 |
+
|
39 |
+
if exit_code and until_successes: pass # Thats right - redability COUNT!
|
40 |
+
else: break
|
41 |
+
|
42 |
+
tracer( "Error while executing {}: {}\n".format(message, output) )
|
43 |
+
tracer("Sleeping 60s... then I will retry...")
|
44 |
+
sys.stdout.flush();
|
45 |
+
time.sleep(60)
|
46 |
+
|
47 |
+
if return_ == 'tuple': return(exit_code, output)
|
48 |
+
|
49 |
+
if exit_code and terminate_on_failure:
|
50 |
+
tracer("\nEncounter error while executing: " + command_line)
|
51 |
+
if return_==True: return True
|
52 |
+
else: print("\nEncounter error while executing: " + command_line + '\n' + output); sys.exit(1)
|
53 |
+
|
54 |
+
if return_ == 'output': return output
|
55 |
+
else: return False
|
56 |
+
|
57 |
+
|
58 |
+
def Sleep(time_, message, dict_={}):
|
59 |
+
''' Fancy sleep function '''
|
60 |
+
len_ = 0
|
61 |
+
for i in range(time_, 0, -1):
|
62 |
+
#print "Waiting for a new revision:%s... Sleeping...%d \r" % (sc.revision, i),
|
63 |
+
msg = message.format( **dict(dict_, time_left=i) )
|
64 |
+
print( msg, end='' )
|
65 |
+
len_ = max(len_, len(msg))
|
66 |
+
sys.stdout.flush()
|
67 |
+
time_module.sleep(1)
|
68 |
+
|
69 |
+
print( ' '*len_ + '\r', end='' ) # erazing sleep message
|
70 |
+
|
71 |
+
|
72 |
+
# Abstract class for HPC job submission
|
73 |
+
class HPC_Driver:
|
74 |
+
def __init__(self, working_dir, config, tracer=lambda x:None, set_daemon_message=lambda x:None):
|
75 |
+
self.working_dir = working_dir
|
76 |
+
self.config = config
|
77 |
+
self.cpu_usage = 0.0 # cummulative cpu usage in hours
|
78 |
+
self.tracer = tracer
|
79 |
+
self.set_daemon_message = set_daemon_message
|
80 |
+
|
81 |
+
self.cpu_count = self.config['cpu_count'] if type(config) == dict else self.config.getint('DEFAULT', 'cpu_count')
|
82 |
+
|
83 |
+
self.jobs = [] # list of all jobs currently running by this driver, Job class is driver depended, could be just int or something more complex
|
84 |
+
|
85 |
+
self.install_signal_handler()
|
86 |
+
|
87 |
+
|
88 |
+
def __del__(self):
|
89 |
+
self.remove_signal_handler()
|
90 |
+
|
91 |
+
|
92 |
+
def execute(self, executable, arguments, working_dir, log_dir=None, name='_no_name_', memory=256, time=24, shell_wrapper=False, block=True):
|
93 |
+
''' Execute given command line on HPC cluster, must accumulate cpu hours in self.cpu_usage '''
|
94 |
+
if log_dir==None: log_dir=self.working_dir
|
95 |
+
|
96 |
+
if shell_wrapper:
|
97 |
+
shell_wrapper_sh = os.path.abspath(self.working_dir + '/hpc.{}.shell_wrapper.sh'.format(name))
|
98 |
+
with file(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
|
99 |
+
executable, arguments = shell_wrapper_sh, ''
|
100 |
+
|
101 |
+
return self.submit_serial_hpc_job(name=name, executable=executable, arguments=arguments, working_dir=working_dir, log_dir=log_dir, jobs_to_queue=1, memory=memory, time=time, block=block, shell_wrapper=shell_wrapper)
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
@property
|
106 |
+
def number_of_cpu_per_node(self):
|
107 |
+
must_be_implemented_in_inherited_classes
|
108 |
+
|
109 |
+
@property
|
110 |
+
def maximum_number_of_mpi_cpu(self):
|
111 |
+
must_be_implemented_in_inherited_classes
|
112 |
+
|
113 |
+
|
114 |
+
def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
115 |
+
print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
|
116 |
+
must_be_implemented_in_inherited_classes
|
117 |
+
|
118 |
+
|
119 |
+
def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
120 |
+
must_be_implemented_in_inherited_classes
|
121 |
+
|
122 |
+
|
123 |
+
def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, memory=512, time=12, block=True, process_coefficient="1", requested_nodes=1, requested_processes_per_node=1):
|
124 |
+
''' submit jobs as MPI job
|
125 |
+
process_coefficient should be string representing fraction of process to launch on each node, for example '3 / 4' will start only 75% of MPI process's on each node
|
126 |
+
'''
|
127 |
+
must_be_implemented_in_inherited_classes
|
128 |
+
|
129 |
+
|
130 |
+
def cancel_all_jobs(self):
|
131 |
+
''' Cancel all HPC jobs known to this driver, use this as signal handler for script termination '''
|
132 |
+
for j in self.jobs: self.cancel_job(j)
|
133 |
+
|
134 |
+
def block_until(self, silent, fn, *args, **kwargs):
|
135 |
+
'''
|
136 |
+
**fn must have the driver as the first argument**
|
137 |
+
example:
|
138 |
+
def fn(driver):
|
139 |
+
jobs = list(driver.jobs)
|
140 |
+
jobs = [job for job in jobs if not driver.complete(job)]
|
141 |
+
if len(jobs) <= 8:
|
142 |
+
return False # stops sleeping
|
143 |
+
return True # continues sleeping
|
144 |
+
|
145 |
+
for x in range(100):
|
146 |
+
hpc_driver.submit_hpc_job(...)
|
147 |
+
hpc_driver.block_until(False, fn)
|
148 |
+
'''
|
149 |
+
while fn(self, *args, **kwargs):
|
150 |
+
sys.stdout.flush()
|
151 |
+
time_module.sleep(60)
|
152 |
+
if not silent:
|
153 |
+
Sleep(1, '"Waiting for HPC job(s) to finish, sleeping {time_left}s\r')
|
154 |
+
|
155 |
+
def wait_until_complete(self, jobs=None, callback=None, silent=False):
|
156 |
+
''' Helper function, wait until given jobs list is finished, if no argument is given waits until all jobs known by driver is finished '''
|
157 |
+
jobs = jobs if jobs else self.jobs
|
158 |
+
|
159 |
+
while jobs:
|
160 |
+
for j in jobs[:]:
|
161 |
+
if self.complete(j): jobs.remove(j)
|
162 |
+
|
163 |
+
if jobs:
|
164 |
+
#total_cpu_queued = sum( [j.jobs_queued for j in jobs] )
|
165 |
+
#total_cpu_running = sum( [j.cpu_running for j in jobs] )
|
166 |
+
#self.set_daemon_message("Waiting for HPC job(s) to finish... [{} process(es) in queue, {} process(es) running]".format(total_cpu_queued, total_cpu_running) )
|
167 |
+
#self.tracer("Waiting for HPC job(s) [{} process(es) in queue, {} process(es) running]... \r".format(total_cpu_queued, total_cpu_running), end='')
|
168 |
+
#print "Waiting for {} HPC jobs to finish... [{} jobs in queue, {} jobs running]... Sleeping 32s... \r".format(total_cpu_queued, cpu_queued+cpu_running, cpu_running),
|
169 |
+
|
170 |
+
self.set_daemon_message("Waiting for HPC {} job(s) to finish...".format( len(jobs) ) )
|
171 |
+
#self.tracer("Waiting for HPC {} job(s) to finish...".format( len(jobs) ) )
|
172 |
+
|
173 |
+
sys.stdout.flush()
|
174 |
+
|
175 |
+
if callback: callback()
|
176 |
+
|
177 |
+
if silent: time_module.sleep(64*1)
|
178 |
+
else: Sleep(64, '"Waiting for HPC {n_jobs} job(s) to finish, sleeping {time_left}s \r', dict(n_jobs=len(jobs)))
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
_signals_ = [signal_module.SIGINT, signal_module.SIGTERM, signal_module.SIGABRT]
|
183 |
+
def install_signal_handler(self):
|
184 |
+
def signal_handler(signal_, frame):
|
185 |
+
self.tracer('Recieved signal:{}... Canceling HPC jobs...'.format(signal_) )
|
186 |
+
self.cancel_all_jobs()
|
187 |
+
self.set_daemon_message( 'Remote daemon got terminated with signal:{}'.format(signal_) )
|
188 |
+
sys.exit(1)
|
189 |
+
|
190 |
+
for s in self._signals_: signal_module.signal(s, signal_handler)
|
191 |
+
|
192 |
+
|
193 |
+
def remove_signal_handler(self): # do we really need this???
|
194 |
+
try:
|
195 |
+
for s in self._signals_: signal_module.signal(s, signal_module.SIG_DFL)
|
196 |
+
#print('remove_signal_handler: done!')
|
197 |
+
|
198 |
+
except TypeError:
|
199 |
+
#print('remove_signal_handler: interpreted terminating, skipping remove_signal_handler...')
|
200 |
+
pass
|
201 |
+
|
202 |
+
|
203 |
+
def cancel_job(self, job_id):
|
204 |
+
must_be_implemented_in_inherited_classes
|
205 |
+
|
206 |
+
|
207 |
+
def complete(self, job_id):
|
208 |
+
''' Return job completion status. Return True if job complered and False otherwise
|
209 |
+
'''
|
210 |
+
must_be_implemented_in_inherited_classes
|
.rosetta-ci/hpc_drivers/multicore.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# :noTabs=true:
|
3 |
+
|
4 |
+
import time as time_module
|
5 |
+
import codecs
|
6 |
+
import signal
|
7 |
+
|
8 |
+
import os, sys
|
9 |
+
|
10 |
+
try:
|
11 |
+
from .base import *
|
12 |
+
|
13 |
+
except ImportError: # workaround for B2 back-end's
|
14 |
+
import imp
|
15 |
+
imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/base.py') # A bit of Python magic here, what we trying to say is this: from base import *, but path to base is calculated from our source location # from base import HPC_Driver, execute, NT
|
16 |
+
|
17 |
+
|
18 |
+
class MultiCore_HPC_Driver(HPC_Driver):
|
19 |
+
|
20 |
+
class JobID:
|
21 |
+
def __init__(self, pids=None):
|
22 |
+
self.pids = pids if pids else []
|
23 |
+
|
24 |
+
|
25 |
+
def __bool__(self): return bool(self.pids)
|
26 |
+
|
27 |
+
|
28 |
+
def __len__(self): return len(self.pids)
|
29 |
+
|
30 |
+
|
31 |
+
def add_pid(self, pid): self.pids.append(pid)
|
32 |
+
|
33 |
+
|
34 |
+
def remove_completed_pids(self):
|
35 |
+
for pid in self.pids[:]:
|
36 |
+
try:
|
37 |
+
r = os.waitpid(pid, os.WNOHANG)
|
38 |
+
if r == (pid, 0): self.pids.remove(pid) # process have ended without error
|
39 |
+
elif r[0] == pid : # process ended but with error, special case we will have to wait for all process to terminate and call system exit.
|
40 |
+
#self.cancel_job()
|
41 |
+
#sys.exit(1)
|
42 |
+
self.pids.remove(pid)
|
43 |
+
print('ERROR: Some of the HPC jobs terminated abnormally! Please see HPC logs for details.')
|
44 |
+
|
45 |
+
except ChildProcessError: self.pids.remove(pid)
|
46 |
+
|
47 |
+
|
48 |
+
def cancel(self):
|
49 |
+
for pid in self.pids:
|
50 |
+
try:
|
51 |
+
os.killpg(os.getpgid(pid), signal.SIGKILL)
|
52 |
+
except ChildProcessError: pass
|
53 |
+
|
54 |
+
self.pids = []
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
def __init__(self, *args, **kwds):
|
59 |
+
HPC_Driver.__init__(self, *args, **kwds)
|
60 |
+
#print(f'MultiCore_HPC_Driver: cpu_count: {self.cpu_count}')
|
61 |
+
|
62 |
+
|
63 |
+
def remove_completed_jobs(self):
|
64 |
+
for job in self.jobs[:]: # Need to make a copy so we don't modify a list we're iterating over
|
65 |
+
job.remove_completed_pids()
|
66 |
+
if not job: self.jobs.remove(job)
|
67 |
+
|
68 |
+
|
69 |
+
@property
|
70 |
+
def process_count(self):
|
71 |
+
''' return number of processes that currently ran by this driver instance
|
72 |
+
'''
|
73 |
+
return sum( map(len, self.jobs) )
|
74 |
+
|
75 |
+
|
76 |
+
def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
77 |
+
print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
|
78 |
+
return self.submit_serial_hpc_job(name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory, time, block, shell_wrapper)
|
79 |
+
|
80 |
+
|
81 |
+
def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
82 |
+
cpu_usage = -time_module.time()/60./60.
|
83 |
+
|
84 |
+
if shell_wrapper:
|
85 |
+
shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
|
86 |
+
with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
|
87 |
+
executable, arguments = shell_wrapper_sh, ''
|
88 |
+
|
89 |
+
def mfork():
|
90 |
+
''' Check if number of child process is below cpu_count. And if it is - fork the new pocees and return its pid.
|
91 |
+
'''
|
92 |
+
while self.process_count >= self.cpu_count:
|
93 |
+
self.remove_completed_jobs()
|
94 |
+
if self.process_count >= self.cpu_count: time_module.sleep(.5)
|
95 |
+
|
96 |
+
sys.stdout.flush()
|
97 |
+
pid = os.fork()
|
98 |
+
# appending at caller level insted if pid: self.jobs.append(pid) # We are parent!
|
99 |
+
return pid
|
100 |
+
|
101 |
+
current_job = self.JobID()
|
102 |
+
process = 0
|
103 |
+
for i in range(jobs_to_queue):
|
104 |
+
|
105 |
+
pid = mfork()
|
106 |
+
if not pid: # we are child process
|
107 |
+
command_line = 'cd {} && {} {}'.format(working_dir, executable, arguments.format(process=process) )
|
108 |
+
exit_code, log = execute('Running job {}.{}...'.format(name, i), command_line, tracer=self.tracer, return_='tuple')
|
109 |
+
with codecs.open(log_dir+'/.hpc.{name}.{i:02d}.log'.format(**vars()), 'w', encoding='utf-8', errors='replace') as f:
|
110 |
+
f.write(command_line+'\n'+log)
|
111 |
+
if exit_code:
|
112 |
+
error_report = f'\n\n{command_line}\nERROR: PROCESS {name}.{i:02d} TERMINATED WITH NON-ZERO-EXIT-CODE {exit_code}!\n'
|
113 |
+
f.write(error_report)
|
114 |
+
print(log, error_report)
|
115 |
+
|
116 |
+
sys.exit(0)
|
117 |
+
|
118 |
+
else: # we are parent!
|
119 |
+
current_job.add_pid(pid)
|
120 |
+
# Need to potentially re-add to list, as remove_completed_jobs() might trim it.
|
121 |
+
if current_job not in self.jobs: self.jobs.append(current_job)
|
122 |
+
|
123 |
+
process += 1
|
124 |
+
|
125 |
+
if block:
|
126 |
+
#for p in all_queued_jobs: os.waitpid(p, 0) # waiting for all child process to termintate...
|
127 |
+
|
128 |
+
self.wait_until_complete(current_job)
|
129 |
+
self.remove_completed_jobs()
|
130 |
+
|
131 |
+
cpu_usage += time_module.time()/60./60.
|
132 |
+
self.cpu_usage += cpu_usage * jobs_to_queue # approximation...
|
133 |
+
|
134 |
+
current_job = self.JobID()
|
135 |
+
|
136 |
+
return current_job
|
137 |
+
|
138 |
+
|
139 |
+
@property
|
140 |
+
def number_of_cpu_per_node(self): return self.cpu_count
|
141 |
+
|
142 |
+
|
143 |
+
@property
|
144 |
+
def maximum_number_of_mpi_cpu(self): return self.cpu_count
|
145 |
+
|
146 |
+
|
147 |
+
def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, memory=512, time=12, block=True, process_coefficient="1", requested_nodes=1, requested_processes_per_node=1):
|
148 |
+
|
149 |
+
if requested_nodes > 1:
|
150 |
+
print( "WARNING: " + str( requested_nodes ) + " nodes were requested, but we're running locally, so only 1 node will be used." )
|
151 |
+
|
152 |
+
if requested_processes_per_node > self.cpu_count:
|
153 |
+
print( "WARNING: " + str(requested_processes_per_node) + " processes were requested, but I only have " + str(self.cpu_count) + " CPUs. Will launch " + str(self.cpu_count) + " processes." )
|
154 |
+
actual_processes = min( requested_processes_per_node, self.cpu_count )
|
155 |
+
|
156 |
+
cpu_usage = -time_module.time()/60./60.
|
157 |
+
|
158 |
+
arguments = arguments.format(process=0)
|
159 |
+
|
160 |
+
command_line = f'cd {working_dir} && mpirun -np {actual_processes} {executable} {arguments}'
|
161 |
+
log = execute(f'Running job {name}...', command_line, tracer=self.tracer, return_='output')
|
162 |
+
with codecs.open(log_dir+'/.hpc.{name}.log'.format(**vars()), 'w', encoding='utf-8', errors='replace') as f: f.write(command_line+'\n'+log)
|
163 |
+
|
164 |
+
cpu_usage += time_module.time()/60./60.
|
165 |
+
self.cpu_usage += cpu_usage * actual_processes # approximation...
|
166 |
+
|
167 |
+
# return None - we do not return anything from this version of submit which imply returning None which in turn will be treated as job-id for already finished job
|
168 |
+
|
169 |
+
|
170 |
+
def complete(self, job_id):
|
171 |
+
''' Return job completion status. Return True if job completed and False otherwise
|
172 |
+
'''
|
173 |
+
self.remove_completed_jobs()
|
174 |
+
return job_id not in self.jobs
|
175 |
+
|
176 |
+
|
177 |
+
def cancel_job(self, job):
|
178 |
+
job.cancel();
|
179 |
+
if job in self.jobs:
|
180 |
+
self.jobs.remove(job)
|
181 |
+
|
182 |
+
|
183 |
+
def __repr__(self):
|
184 |
+
return 'MultiCore_HPC_Driver<>'
|
.rosetta-ci/hpc_drivers/slurm.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# :noTabs=true:
|
3 |
+
|
4 |
+
import os, sys, time, collections, math
|
5 |
+
import stat as stat_module
|
6 |
+
|
7 |
+
|
8 |
+
try:
|
9 |
+
from .base import *
|
10 |
+
|
11 |
+
except ImportError: # workaround for B2 back-end's
|
12 |
+
import imp
|
13 |
+
imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/base.py') # A bit of Python magic here, what we trying to say is this: from base import *, but path to base is calculated from our source location # from base import HPC_Driver, execute, NT
|
14 |
+
|
15 |
+
|
16 |
+
_T_slurm_array_job_template_ = '''\
|
17 |
+
#!/bin/bash
|
18 |
+
#
|
19 |
+
#SBATCH --job-name={name}
|
20 |
+
#SBATCH --output={log_dir}/.hpc.%x.%a.output
|
21 |
+
#
|
22 |
+
#SBATCH --time={time}:00
|
23 |
+
#SBATCH --mem-per-cpu={memory}M
|
24 |
+
#SBATCH --chdir={working_dir}
|
25 |
+
#
|
26 |
+
#SBATCH --array=1-{jobs_to_queue}
|
27 |
+
|
28 |
+
srun {executable} {arguments}
|
29 |
+
'''
|
30 |
+
|
31 |
+
_T_slurm_mpi_job_template_ = '''\
|
32 |
+
#!/bin/bash
|
33 |
+
#
|
34 |
+
#SBATCH --job-name={name}
|
35 |
+
#SBATCH --output={log_dir}/.hpc.%x.output
|
36 |
+
#
|
37 |
+
#SBATCH --time={time}:00
|
38 |
+
#SBATCH --mem-per-cpu={memory}M
|
39 |
+
#SBATCH --chdir={working_dir}
|
40 |
+
#
|
41 |
+
#SBATCH --ntasks={ntasks}
|
42 |
+
|
43 |
+
mpirun {executable} {arguments}
|
44 |
+
'''
|
45 |
+
|
46 |
+
class Slurm_HPC_Driver(HPC_Driver):
|
47 |
+
def head_node_execute(self, message, command_line, *args, **kwargs):
|
48 |
+
head_node = self.config['slurm'].get('head_node')
|
49 |
+
|
50 |
+
command_line, host = (f"ssh {head_node} cd `pwd` '&& {command_line}'", head_node) if head_node else (command_line, 'localhost')
|
51 |
+
return execute(f'Executiong on {host}: {message}' if message else '', command_line, *args, **kwargs)
|
52 |
+
|
53 |
+
|
54 |
+
# NodeGroup = collections.namedtuple('NodeGroup', 'nodes cores')
|
55 |
+
|
56 |
+
# @property
|
57 |
+
# def mpi_topology(self):
|
58 |
+
# ''' return list of NodeGroup's
|
59 |
+
# '''
|
60 |
+
# pass
|
61 |
+
|
62 |
+
|
63 |
+
# @property
|
64 |
+
# def number_of_cpu_per_node(self): return int( self.config['condor']['mpi_cpu_per_node'] )
|
65 |
+
|
66 |
+
# @property
|
67 |
+
# def maximum_number_of_mpi_cpu(self):
|
68 |
+
# return self.number_of_cpu_per_node * int( self.config['condor']['mpi_maximum_number_of_nodes'] )
|
69 |
+
|
70 |
+
|
71 |
+
# def complete(self, condor_job_id):
|
72 |
+
# ''' Return job completion status. Note that single hpc_job may contatin inner list of individual HPC jobs, True should be return if they all run in to completion.
|
73 |
+
# '''
|
74 |
+
|
75 |
+
# execute('Releasing condor jobs...', 'condor_release $USER', return_='tuple')
|
76 |
+
|
77 |
+
# s = execute('', 'condor_q $USER | grep $USER | grep {}'.format(condor_job_id), return_='output', terminate_on_failure=False).replace(' ', '').replace('\n', '')
|
78 |
+
# if s: return False
|
79 |
+
|
80 |
+
# # #setDaemonStatusAndPing('[Job #%s] Running... %s condor job(s) in queue...' % (self.id, len(s.split('\n') ) ) )
|
81 |
+
# # n_jobs = len(s.split('\n'))
|
82 |
+
# # s, o = execute('', 'condor_userprio -all | grep $USER@', return_='tuple')
|
83 |
+
# # if s == 0:
|
84 |
+
# # jobs_running = o.split()
|
85 |
+
# # jobs_running = 'XX' if len(jobs_running) < 4 else jobs_running[4]
|
86 |
+
# # self.set_daemon_message("Waiting for condor to finish HPC jobs... [{} jobs in HPC-Queue, {} CPU's used]".format(n_jobs, jobs_running) )
|
87 |
+
# # print "{} condor jobs in queue... Sleeping 32s... \r".format(n_jobs),
|
88 |
+
# # sys.stdout.flush()
|
89 |
+
# # time.sleep(32)
|
90 |
+
# else:
|
91 |
+
|
92 |
+
# #self.tracer('Waiting for condor to finish the jobs... DONE')
|
93 |
+
# self.jobs.remove(condor_job_id)
|
94 |
+
# self.cpu_usage += self.get_condor_accumulated_usage()
|
95 |
+
# return True # jobs already finished, we return empty list to prevent double counting of cpu_usage
|
96 |
+
|
97 |
+
|
98 |
+
def complete(self, slurm_job_id):
|
99 |
+
''' Return True if job with given id is complete
|
100 |
+
'''
|
101 |
+
|
102 |
+
s = self.head_node_execute('', f'squeue -j {slurm_job_id} --noheader', return_='output', terminate_on_failure=False, silent=True)
|
103 |
+
if s: return False
|
104 |
+
else:
|
105 |
+
#self.tracer('Waiting for condor to finish the jobs... DONE')
|
106 |
+
self.jobs.remove(slurm_job_id)
|
107 |
+
return True # jobs already finished, we return empty list to prevent double counting of cpu_usage
|
108 |
+
|
109 |
+
|
110 |
+
def cancel_job(self, slurm_job_id):
|
111 |
+
self.head_node_execute(f'Slurm_HPC_Driver.canceling job {slurm_job_id}...', f'scancel {slurm_job_id}', terminate_on_failure=False)
|
112 |
+
|
113 |
+
|
114 |
+
# def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
115 |
+
# print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
|
116 |
+
# return self.submit_serial_hpc_job(name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory, time, block, shell_wrapper)
|
117 |
+
|
118 |
+
|
119 |
+
def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
120 |
+
|
121 |
+
arguments = arguments.format(process='%a') # %a is SLURM array index
|
122 |
+
time = int( math.ceil(time*60) )
|
123 |
+
|
124 |
+
if shell_wrapper:
|
125 |
+
shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
|
126 |
+
with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
|
127 |
+
executable, arguments = shell_wrapper_sh, ''
|
128 |
+
|
129 |
+
slurm_file = working_dir + f'/.hpc.{name}.slurm'
|
130 |
+
|
131 |
+
with open(slurm_file, 'w') as f: f.write( _T_slurm_array_job_template_.format( **vars() ) )
|
132 |
+
|
133 |
+
|
134 |
+
slurm_job_id = self.head_node_execute('Submitting SLURM array job...', f'cd {self.working_dir} && sbatch {slurm_file}',
|
135 |
+
tracer=self.tracer, return_='output'
|
136 |
+
).split()[-1] # expecting something like `Submitted batch job 6122` in output
|
137 |
+
|
138 |
+
|
139 |
+
self.jobs.append(slurm_job_id)
|
140 |
+
|
141 |
+
if block:
|
142 |
+
self.wait_until_complete( [slurm_job_id] )
|
143 |
+
return None
|
144 |
+
|
145 |
+
else: return slurm_job_id
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, ntasks, memory=512, time=12, block=True, shell_wrapper=False):
|
152 |
+
''' submit jobs as MPI job
|
153 |
+
'''
|
154 |
+
arguments = arguments.format(process='0')
|
155 |
+
time = int( math.ceil(time*60) )
|
156 |
+
|
157 |
+
if shell_wrapper:
|
158 |
+
shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
|
159 |
+
with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
|
160 |
+
executable, arguments = shell_wrapper_sh, ''
|
161 |
+
|
162 |
+
slurm_file = working_dir + f'/.hpc.{name}.slurm'
|
163 |
+
|
164 |
+
with open(slurm_file, 'w') as f: f.write( _T_slurm_mpi_job_template_.format( **vars() ) )
|
165 |
+
|
166 |
+
slurm_job_id = self.head_node_execute('Submitting SLURM mpi job...', f'cd {self.working_dir} && sbatch {slurm_file}',
|
167 |
+
tracer=self.tracer, return_='output'
|
168 |
+
).split()[-1] # expecting something like `Submitted batch job 6122` in output
|
169 |
+
|
170 |
+
self.jobs.append(slurm_job_id)
|
171 |
+
|
172 |
+
if block:
|
173 |
+
self.wait_until_complete( [slurm_job_id] )
|
174 |
+
return None
|
175 |
+
|
176 |
+
else: return slurm_job_id
|
.rosetta-ci/test-sets.yaml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# map platform-string → platform definiton
|
2 |
+
platforms:
|
3 |
+
ubuntu-20.04.gcc:
|
4 |
+
os: ubuntu-20.04
|
5 |
+
compiler: gcc
|
6 |
+
python: '3.9'
|
7 |
+
|
8 |
+
ubuntu-20.04.clang:
|
9 |
+
os: ubuntu-20.04
|
10 |
+
compiler: clang
|
11 |
+
python: '3.9'
|
12 |
+
|
13 |
+
|
14 |
+
# map of test-set-name → tests
|
15 |
+
test-sets:
|
16 |
+
main:
|
17 |
+
- ubuntu-20.04.clang.rfd
|
18 |
+
|
19 |
+
python:
|
20 |
+
- ubuntu-20.04.gcc.self.python
|
21 |
+
- ubuntu-20.04.clang.self.python
|
22 |
+
|
23 |
+
self:
|
24 |
+
- ubuntu-20.04.gcc.self.state
|
25 |
+
- ubuntu-20.04.gcc.self.subtests
|
26 |
+
- ubuntu-20.04.gcc.self.release
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
# map of GitHub-label → [test-set]
|
31 |
+
github-label-test-sets:
|
32 |
+
00 main: [main]
|
33 |
+
10 self: [self]
|
34 |
+
16 python: [python]
|
35 |
+
|
36 |
+
|
37 |
+
# map of submit-page-category → tests
|
38 |
+
# tests that does not get assigned will be automatically displayed in 'other' category
|
39 |
+
category-tests:
|
40 |
+
main:
|
41 |
+
- rfd
|
42 |
+
|
43 |
+
self:
|
44 |
+
- self.state
|
45 |
+
- self.subtests
|
46 |
+
- self.release
|
47 |
+
- self.python
|
48 |
+
|
49 |
+
|
50 |
+
# map branch → test-set to
|
51 |
+
# specify list of tests that should be applied by-default during testing of each new commits to specific branch
|
52 |
+
branch-test-sets:
|
53 |
+
main: [main]
|
54 |
+
benchmark: [main, python]
|
55 |
+
|
56 |
+
|
57 |
+
# map branch → test-sets for pull-request's
|
58 |
+
# specify which test-sets should be scheduled for PR's by-default (ie in addition to GH labels applied)
|
59 |
+
# use empty branch name to specify defult value for (ie any branch not explicitly listed)
|
60 |
+
pull-request-branch-test-sets:
|
61 |
+
# specific test sets for benchmark branch
|
62 |
+
benchmark: ['main', 'python']
|
63 |
+
|
64 |
+
# default, will apply to PR's to any other branch
|
65 |
+
'': ['main']
|
.rosetta-ci/tests/__init__.py
ADDED
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# :noTabs=true:
|
4 |
+
|
5 |
+
# (c) Copyright Rosetta Commons Member Institutions.
|
6 |
+
# (c) This file is part of the Rosetta software suite and is made available under license.
|
7 |
+
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
|
8 |
+
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
|
9 |
+
# (c) addressed to University of Washington CoMotion, email: license@uw.edu.
|
10 |
+
|
11 |
+
## @file tests/__init__.py
|
12 |
+
## @brief Common constats and types for all test types
|
13 |
+
## @author Sergey Lyskov
|
14 |
+
|
15 |
+
import os, time, sys, shutil, codecs, urllib.request, imp, subprocess, json, hashlib # urllib.error, urllib.parse,
|
16 |
+
import platform as platform_module
|
17 |
+
import types as types_module
|
18 |
+
|
19 |
+
# ⚔ do not change wording below, it have to stay in sync with upstream (up to benchmark-model).
|
20 |
+
# Copied from benchmark-model, standard state code's for tests results.
|
21 |
+
|
22 |
+
__all__ = ['execute',
|
23 |
+
'_S_Values_', '_S_draft_', '_S_queued_', '_S_running_', '_S_passed_', '_S_failed_', '_S_build_failed_', '_S_script_failed_',
|
24 |
+
'_StateKey_', '_ResultsKey_', '_LogKey_', '_DescriptionKey_', '_TestsKey_',
|
25 |
+
'_multi_step_config_', '_multi_step_error_', '_multi_step_result_',
|
26 |
+
'to_bytes',
|
27 |
+
]
|
28 |
+
|
29 |
+
_S_draft_ = 'draft'
|
30 |
+
_S_queued_ = 'queued'
|
31 |
+
_S_running_ = 'running'
|
32 |
+
_S_passed_ = 'passed'
|
33 |
+
_S_failed_ = 'failed'
|
34 |
+
_S_build_failed_ = 'build failed'
|
35 |
+
_S_script_failed_ = 'script failed'
|
36 |
+
_S_queued_for_comparison_ = 'queued for comparison'
|
37 |
+
|
38 |
+
_S_Values_ = [_S_draft_, _S_queued_, _S_running_, _S_passed_, _S_failed_, _S_build_failed_, _S_script_failed_, _S_queued_for_comparison_]
|
39 |
+
|
40 |
+
_IgnoreKey_ = 'ignore'
|
41 |
+
_StateKey_ = 'state'
|
42 |
+
_ResultsKey_ = 'results'
|
43 |
+
_LogKey_ = 'log'
|
44 |
+
_DescriptionKey_ = 'description'
|
45 |
+
_TestsKey_ = 'tests'
|
46 |
+
_SummaryKey_ = 'summary'
|
47 |
+
_FailedKey_ = 'failed'
|
48 |
+
_TotalKey_ = 'total'
|
49 |
+
_PlotsKey_ = 'plots'
|
50 |
+
_FailedTestsKey_ = 'failed_tests'
|
51 |
+
_HtmlKey_ = 'html'
|
52 |
+
|
53 |
+
# file names for multi-step test files
|
54 |
+
_multi_step_config_ = 'config.json'
|
55 |
+
_multi_step_error_ = 'error.json'
|
56 |
+
_multi_step_result_ = 'result.json'
|
57 |
+
|
58 |
+
PyRosetta_unix_memory_requirement_per_cpu = 6 # Memory per sub-process in Gb's
|
59 |
+
PyRosetta_unix_unit_test_memory_requirement_per_cpu = 3.0 # Memory per sub-process in Gb's for running PyRosetta unit tests
|
60 |
+
|
61 |
+
# Commands to run all the scripts needed for setting up Rosetta compiles. (Run from main/source directory)
|
62 |
+
PRE_COMPILE_SETUP_SCRIPTS = [ "./update_options.sh", "./update_submodules.sh", "./update_ResidueType_enum_files.sh", "python version.py" ]
|
63 |
+
|
64 |
+
DEFAULT_PYTHON_VERSION='3.9'
|
65 |
+
|
66 |
+
# Standard funtions and classes below ---------------------------------------------------------------------------------
|
67 |
+
|
68 |
+
class BenchmarkError(Exception):
|
69 |
+
def __init__(self, value): self.value = value
|
70 |
+
def __repr__(self): return self.value
|
71 |
+
def __str__(self): return self.value
|
72 |
+
|
73 |
+
|
74 |
+
class NT: # named tuple
|
75 |
+
def __init__(self, **entries): self.__dict__.update(entries)
|
76 |
+
def __repr__(self):
|
77 |
+
r = 'NT: |'
|
78 |
+
for i in dir(self):
|
79 |
+
print(i)
|
80 |
+
if not i.startswith('__') and i != '_as_dict' and not isinstance(getattr(self, i), types_module.MethodType): r += '%s --> %s, ' % (i, getattr(self, i))
|
81 |
+
return r[:-2]+'|'
|
82 |
+
|
83 |
+
@property
|
84 |
+
def _as_dict(self):
|
85 |
+
return { a: getattr(self, a) for a in dir(self) if not a.startswith('__') and a != '_as_dict' and not isinstance(getattr(self, a), types_module.MethodType)}
|
86 |
+
|
87 |
+
|
88 |
+
def Tracer(verbose=False):
|
89 |
+
return print if verbose else lambda x: None
|
90 |
+
|
91 |
+
|
92 |
+
def to_unicode(b):
|
93 |
+
''' Conver bytes to string and handle the errors. If argument is already in string - do nothing
|
94 |
+
'''
|
95 |
+
#return b if type(b) == unicode else unicode(b, 'utf-8', errors='replace')
|
96 |
+
return b if type(b) == str else str(b, 'utf-8', errors='backslashreplace')
|
97 |
+
|
98 |
+
|
99 |
+
def to_bytes(u):
|
100 |
+
''' Conver string to bytes and handle the errors. If argument is already of type bytes - do nothing
|
101 |
+
'''
|
102 |
+
return u if type(u) == bytes else u.encode('utf-8', errors='backslashreplace')
|
103 |
+
|
104 |
+
|
105 |
+
''' Python-2 version
|
106 |
+
def execute(message, commandline, return_=False, until_successes=False, terminate_on_failure=True, add_message_and_command_line_to_output=False):
|
107 |
+
message, commandline = to_unicode(message), to_unicode(commandline)
|
108 |
+
|
109 |
+
TR = Tracer()
|
110 |
+
TR(message); TR(commandline)
|
111 |
+
while True:
|
112 |
+
(res, output) = commands.getstatusoutput(commandline)
|
113 |
+
# Subprocess results will always be a bytes-string.
|
114 |
+
# Probably ASCII, but may have some Unicode characters.
|
115 |
+
# A UTF-8 decode will probably get decent results 99% of the time
|
116 |
+
# and the replace option will gracefully handle the rest.
|
117 |
+
output = to_unicode(output)
|
118 |
+
|
119 |
+
TR(output)
|
120 |
+
|
121 |
+
if res and until_successes: pass # Thats right - redability COUNT!
|
122 |
+
else: break
|
123 |
+
|
124 |
+
print( "Error while executing %s: %s\n" % (message, output) )
|
125 |
+
print( "Sleeping 60s... then I will retry..." )
|
126 |
+
time.sleep(60)
|
127 |
+
|
128 |
+
if add_message_and_command_line_to_output: output = message + '\nCommand line: ' + commandline + '\n' + output
|
129 |
+
|
130 |
+
if return_ == 'tuple': return(res, output)
|
131 |
+
|
132 |
+
if res and terminate_on_failure:
|
133 |
+
TR("\nEncounter error while executing: " + commandline)
|
134 |
+
if return_==True: return res
|
135 |
+
else:
|
136 |
+
print("\nEncounter error while executing: " + commandline + '\n' + output)
|
137 |
+
raise BenchmarkError("\nEncounter error while executing: " + commandline + '\n' + output)
|
138 |
+
|
139 |
+
if return_ == 'output': return output
|
140 |
+
else: return res
|
141 |
+
'''
|
142 |
+
|
143 |
+
def execute_through_subprocess(command_line):
|
144 |
+
# exit_code, output = subprocess.getstatusoutput(command_line)
|
145 |
+
|
146 |
+
# p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
147 |
+
# output, errors = p.communicate()
|
148 |
+
# output = (output + errors).decode(encoding='utf-8', errors='backslashreplace')
|
149 |
+
# exit_code = p.returncode
|
150 |
+
|
151 |
+
# previous 'main' version based on subprocess module. Main issue that output of segfaults will not be captured since they generated by shell
|
152 |
+
p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
153 |
+
output, errors = p.communicate()
|
154 |
+
# output = output + errors # ← we redirected stderr into same pipe as stdcout so errors is None, - no need to concatenate
|
155 |
+
output = output.decode(encoding='utf-8', errors='backslashreplace')
|
156 |
+
exit_code = p.returncode
|
157 |
+
|
158 |
+
return exit_code, output
|
159 |
+
|
160 |
+
|
161 |
+
def execute_through_pexpect(command_line):
|
162 |
+
import pexpect
|
163 |
+
|
164 |
+
child = pexpect.spawn('/bin/bash', ['-c', command_line])
|
165 |
+
child.expect(pexpect.EOF)
|
166 |
+
output = child.before.decode(encoding='utf-8', errors='backslashreplace')
|
167 |
+
child.close()
|
168 |
+
exit_code = child.signalstatus or child.exitstatus
|
169 |
+
|
170 |
+
return exit_code, output
|
171 |
+
|
172 |
+
|
173 |
+
def execute_through_pty(command_line):
|
174 |
+
import pty, select
|
175 |
+
|
176 |
+
if sys.platform == "darwin":
|
177 |
+
|
178 |
+
master, slave = pty.openpty()
|
179 |
+
p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
|
180 |
+
stderr=subprocess.STDOUT, close_fds=True)
|
181 |
+
|
182 |
+
buffer = []
|
183 |
+
while True:
|
184 |
+
try:
|
185 |
+
if select.select([master], [], [], 0.2)[0]: # has something to read
|
186 |
+
data = os.read(master, 1 << 22)
|
187 |
+
if data: buffer.append(data)
|
188 |
+
|
189 |
+
elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0] ): break # process is finished and output buffer if fully read
|
190 |
+
|
191 |
+
except OSError: break # OSError will be raised when child process close PTY descriptior
|
192 |
+
|
193 |
+
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
|
194 |
+
|
195 |
+
os.close(master)
|
196 |
+
os.close(slave)
|
197 |
+
|
198 |
+
p.wait()
|
199 |
+
exit_code = p.returncode
|
200 |
+
|
201 |
+
'''
|
202 |
+
buffer = []
|
203 |
+
while True:
|
204 |
+
if select.select([master], [], [], 0.2)[0]: # has something to read
|
205 |
+
data = os.read(master, 1 << 22)
|
206 |
+
if data: buffer.append(data)
|
207 |
+
# else: break # # EOF - well, technically process _should_ be finished here...
|
208 |
+
|
209 |
+
# elif time.sleep(1) or (p.poll() is not None): # process is finished (sleep here is intentional to trigger race condition, see solution for this on the next few lines)
|
210 |
+
# assert not select.select([master], [], [], 0.2)[0] # should be nothing left to read...
|
211 |
+
# break
|
212 |
+
|
213 |
+
elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0] ): break # process is finished and output buffer if fully read
|
214 |
+
|
215 |
+
assert not select.select([master], [], [], 0.2)[0] # should be nothing left to read...
|
216 |
+
|
217 |
+
os.close(slave)
|
218 |
+
os.close(master)
|
219 |
+
|
220 |
+
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
|
221 |
+
exit_code = p.returncode
|
222 |
+
'''
|
223 |
+
|
224 |
+
else:
|
225 |
+
|
226 |
+
master, slave = pty.openpty()
|
227 |
+
p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
|
228 |
+
stderr=subprocess.STDOUT, close_fds=True)
|
229 |
+
|
230 |
+
os.close(slave)
|
231 |
+
|
232 |
+
buffer = []
|
233 |
+
while True:
|
234 |
+
try:
|
235 |
+
data = os.read(master, 1 << 22)
|
236 |
+
if data: buffer.append(data)
|
237 |
+
except OSError: break # OSError will be raised when child process close PTY descriptior
|
238 |
+
|
239 |
+
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
|
240 |
+
|
241 |
+
os.close(master)
|
242 |
+
|
243 |
+
p.wait()
|
244 |
+
exit_code = p.returncode
|
245 |
+
|
246 |
+
return exit_code, output
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
def execute(message, command_line, return_='status', until_successes=False, terminate_on_failure=True, silent=False, silence_output=False, silence_output_on_errors=False, add_message_and_command_line_to_output=False):
|
251 |
+
if not silent: print(message); print(command_line); sys.stdout.flush();
|
252 |
+
while True:
|
253 |
+
|
254 |
+
#exit_code, output = execute_through_subprocess(command_line)
|
255 |
+
#exit_code, output = execute_through_pexpect(command_line)
|
256 |
+
exit_code, output = execute_through_pty(command_line)
|
257 |
+
|
258 |
+
if (exit_code and not silence_output_on_errors) or not (silent or silence_output): print(output); sys.stdout.flush();
|
259 |
+
|
260 |
+
if exit_code and until_successes: pass # Thats right - redability COUNT!
|
261 |
+
else: break
|
262 |
+
|
263 |
+
print( "Error while executing {}: {}\n".format(message, output) )
|
264 |
+
print("Sleeping 60s... then I will retry...")
|
265 |
+
sys.stdout.flush();
|
266 |
+
time.sleep(60)
|
267 |
+
|
268 |
+
if add_message_and_command_line_to_output: output = message + '\nCommand line: ' + command_line + '\n' + output
|
269 |
+
|
270 |
+
if return_ == 'tuple' or return_ == tuple: return(exit_code, output)
|
271 |
+
|
272 |
+
if exit_code and terminate_on_failure:
|
273 |
+
print("\nEncounter error while executing: " + command_line)
|
274 |
+
if return_==True: return True
|
275 |
+
else:
|
276 |
+
print('\nEncounter error while executing: ' + command_line + '\n' + output);
|
277 |
+
raise BenchmarkError('\nEncounter error while executing: ' + command_line + '\n' + output)
|
278 |
+
|
279 |
+
if return_ == 'output': return output
|
280 |
+
else: return exit_code
|
281 |
+
|
282 |
+
|
283 |
+
def parallel_execute(name, jobs, rosetta_dir, working_dir, cpu_count, time=16):
|
284 |
+
''' Execute command line in parallel on local host
|
285 |
+
time specifies the upper limit for cpu-usage runtime (in minutes) for any one process in the parallel execution.
|
286 |
+
|
287 |
+
jobs should be dict with following structure:
|
288 |
+
{
|
289 |
+
'job-string-id-1’: command_line-1,
|
290 |
+
'job-string-id-2’: command_line-2,
|
291 |
+
...
|
292 |
+
}
|
293 |
+
|
294 |
+
return: dict with jobs-id's as keys and value as dict with 'output' and 'result' keys:
|
295 |
+
{
|
296 |
+
"job-string-id-1": {
|
297 |
+
"output": "stdout + stdderr output of command_line-1",
|
298 |
+
"result": <integer exit code for command_line-1>
|
299 |
+
},
|
300 |
+
"c2": {
|
301 |
+
"output": "stdout + stdderr output of command_line-2",
|
302 |
+
"result": <integer exit code for command_line-2>
|
303 |
+
},
|
304 |
+
...
|
305 |
+
}
|
306 |
+
'''
|
307 |
+
job_file_name = working_dir + '/' + name
|
308 |
+
with open(job_file_name + '.json', 'w') as f: json.dump(jobs, f, sort_keys=True, indent=2) # JSON handles unicode internally
|
309 |
+
if time is not None:
|
310 |
+
allowed_time = int(time*60)
|
311 |
+
ulimit_command = f'ulimit -t {allowed_time} && '
|
312 |
+
else:
|
313 |
+
ulimit_command = ''
|
314 |
+
command = f'cd {working_dir} && ' + ulimit_command + f'{rosetta_dir}/tests/benchmark/util/parallel.py -j{cpu_count} {job_file_name}.json'
|
315 |
+
execute("Running {} in parallel with {} CPU's...".format(name, cpu_count), command )
|
316 |
+
|
317 |
+
with open(job_file_name+'.results.json') as f: return json.load(f)
|
318 |
+
|
319 |
+
|
320 |
+
def calculate_unique_prefix_path(platform, config):
|
321 |
+
''' calculate path for prefix location that is unique for this machine and OS
|
322 |
+
'''
|
323 |
+
hostname = os.uname()[1]
|
324 |
+
return config['prefix'] + '/' + hostname + '/' + platform['os']
|
325 |
+
|
326 |
+
|
327 |
+
def get_python_include_and_lib(python):
|
328 |
+
''' calculate python include dir and lib dir from given python executable path
|
329 |
+
'''
|
330 |
+
#python = os.path.realpath(python)
|
331 |
+
python_bin_dir = python.rpartition('/')[0]
|
332 |
+
python_config = f'{python} {python}-config' if python.endswith('2.7') else f'{python}-config'
|
333 |
+
|
334 |
+
#if not os.path.isfile(python_config): python_config = python_bin_dir + '/python-config'
|
335 |
+
|
336 |
+
info = execute('Getting python configuration info...', f'unset __PYVENV_LAUNCHER__ && cd {python_bin_dir} && PATH=.:$PATH && {python_config} --prefix --includes', return_='output').replace('\r', '').split('\n') # Python-3 only: --abiflags
|
337 |
+
python_prefix = info[0]
|
338 |
+
python_include_dir = info[1].split()[0][len('-I'):]
|
339 |
+
python_lib_dir = python_prefix + '/lib'
|
340 |
+
#python_abi_suffix = info[2]
|
341 |
+
#print(python_include_dir, python_lib_dir)
|
342 |
+
|
343 |
+
return NT(python_include_dir=python_include_dir, python_lib_dir=python_lib_dir)
|
344 |
+
|
345 |
+
|
346 |
+
def local_open_ssl_install(prefix, build_prefix, jobs):
|
347 |
+
''' install OpenSSL at given prefix, return url of source archive
|
348 |
+
'''
|
349 |
+
#with tempfile.TemporaryDirectory('open_ssl_build', dir=prefix) as build_prefix:
|
350 |
+
|
351 |
+
url = 'https://www.openssl.org/source/openssl-1.1.1b.tar.gz'
|
352 |
+
#url = 'https://www.openssl.org/source/openssl-3.0.0.tar.gz'
|
353 |
+
|
354 |
+
|
355 |
+
archive = build_prefix + '/' + url.split('/')[-1]
|
356 |
+
build_dir = archive.rpartition('.tar.gz')[0]
|
357 |
+
if os.path.isdir(build_dir): shutil.rmtree(build_dir)
|
358 |
+
|
359 |
+
with open(archive, 'wb') as f:
|
360 |
+
response = urllib.request.urlopen(url)
|
361 |
+
f.write( response.read() )
|
362 |
+
|
363 |
+
execute('Unpacking {}'.format(archive), 'cd {build_prefix} && tar -xvzf {archive}'.format(**vars()) )
|
364 |
+
|
365 |
+
execute('Configuring...', f'cd {build_dir} && ./config --prefix={prefix}')
|
366 |
+
execute('Building...', f'cd {build_dir} && make -j{jobs}')
|
367 |
+
execute('Installing...', f'cd {build_dir} && make -j{jobs} install')
|
368 |
+
|
369 |
+
return url
|
370 |
+
|
371 |
+
|
372 |
+
def remove_pip_and_easy_install(prefix_root_path):
|
373 |
+
''' remove `pip` and `easy_install` executable from given Python / virtual-environments install
|
374 |
+
'''
|
375 |
+
for f in os.listdir(prefix_root_path + '/bin'): # removing all pip's and easy_install's to make sure that environment is immutable
|
376 |
+
for p in ['pip', 'easy_install']:
|
377 |
+
if f.startswith(p): os.remove(prefix_root_path + '/bin/' + f)
|
378 |
+
|
379 |
+
|
380 |
+
|
381 |
+
def local_python_install(platform, config):
|
382 |
+
''' Perform local install of given Python version and return path-to-python-interpreter, python_include_dir, python_lib_dir
|
383 |
+
If previous install is detected skip installiation.
|
384 |
+
Provided Python install will _persistent_ and _immutable_
|
385 |
+
'''
|
386 |
+
jobs = config['cpu_count']
|
387 |
+
compiler, cpp_compiler = ('clang', 'clang++') if platform['os'] == 'mac' else ('gcc', 'g++') # disregarding platform compiler setting and instead use default compiler for platform
|
388 |
+
|
389 |
+
python_version = platform.get('python', DEFAULT_PYTHON_VERSION)
|
390 |
+
|
391 |
+
if python_version.endswith('.s'):
|
392 |
+
assert python_version == f'{sys.version_info.major}.{sys.version_info.minor}.s'
|
393 |
+
#root = executable.rpartition('/bin/python')[0]
|
394 |
+
h = hashlib.md5(); h.update( (sys.executable + sys.version).encode('utf-8', errors='backslashreplace') ); hash = h.hexdigest()
|
395 |
+
return NT(
|
396 |
+
python = sys.executable,
|
397 |
+
root = None,
|
398 |
+
python_include_dir = None,
|
399 |
+
python_lib_dir = None,
|
400 |
+
version = python_version,
|
401 |
+
url = None,
|
402 |
+
platform = platform,
|
403 |
+
config = config,
|
404 |
+
hash = hash,
|
405 |
+
)
|
406 |
+
|
407 |
+
# deprecated, no longer needed
|
408 |
+
# python_version = {'python2' : '2.7',
|
409 |
+
# 'python2.7' : '2.7',
|
410 |
+
# 'python3' : '3.5',
|
411 |
+
# }.get(python_version, python_version)
|
412 |
+
|
413 |
+
# for security reasons we only allow installs for version listed here with hand-coded URL's
|
414 |
+
python_sources = {
|
415 |
+
'2.7' : 'https://www.python.org/ftp/python/2.7.18/Python-2.7.18.tgz',
|
416 |
+
|
417 |
+
'3.5' : 'https://www.python.org/ftp/python/3.5.9/Python-3.5.9.tgz',
|
418 |
+
'3.6' : 'https://www.python.org/ftp/python/3.6.15/Python-3.6.15.tgz',
|
419 |
+
'3.7' : 'https://www.python.org/ftp/python/3.7.14/Python-3.7.14.tgz',
|
420 |
+
'3.8' : 'https://www.python.org/ftp/python/3.8.14/Python-3.8.14.tgz',
|
421 |
+
'3.9' : 'https://www.python.org/ftp/python/3.9.14/Python-3.9.14.tgz',
|
422 |
+
'3.10' : 'https://www.python.org/ftp/python/3.10.10/Python-3.10.10.tgz',
|
423 |
+
'3.11' : 'https://www.python.org/ftp/python/3.11.2/Python-3.11.2.tgz',
|
424 |
+
}
|
425 |
+
|
426 |
+
# map of env -> ('shell-code-before ./configure', 'extra-arguments-for-configure')
|
427 |
+
extras = {
|
428 |
+
#('mac',) : ('__PYVENV_LAUNCHER__="" MACOSX_DEPLOYMENT_TARGET={}'.format(platform_module.mac_ver()[0]), ''), # __PYVENV_LAUNCHER__ now used by-default for all platform installs
|
429 |
+
('mac',) : ('MACOSX_DEPLOYMENT_TARGET={}'.format(platform_module.mac_ver()[0]), ''),
|
430 |
+
('linux', '2.7') : ('', '--enable-unicode=ucs4'),
|
431 |
+
('ubuntu', '2.7') : ('', '--enable-unicode=ucs4'),
|
432 |
+
}
|
433 |
+
|
434 |
+
#packages = '' if (python_version[0] == '2' or python_version == '3.5' ) and platform['os'] == 'mac' else 'pip setuptools wheel' # 2.7 is now deprecated on Mac so some packages could not be installed
|
435 |
+
packages = 'setuptools'
|
436 |
+
|
437 |
+
url = python_sources[python_version]
|
438 |
+
|
439 |
+
extra = extras.get( (platform['os'],) , ('', '') )
|
440 |
+
extra = extras.get( (platform['os'], python_version) , extra)
|
441 |
+
|
442 |
+
extra = ('unset __PYVENV_LAUNCHER__ && ' + extra[0], extra[1])
|
443 |
+
|
444 |
+
options = '--with-ensurepip' #'--without-ensurepip'
|
445 |
+
signature = f'v1.5.1 url: {url}\noptions: {options}\ncompiler: {compiler}\nextra: {extra}\npackages: {packages}\n'
|
446 |
+
|
447 |
+
h = hashlib.md5(); h.update( signature.encode('utf-8', errors='backslashreplace') ); hash = h.hexdigest()
|
448 |
+
|
449 |
+
root = calculate_unique_prefix_path(platform, config) + '/python-' + python_version + '.' + compiler + '/' + hash
|
450 |
+
|
451 |
+
signature_file_name = root + '/.signature'
|
452 |
+
|
453 |
+
#activate = root + '/bin/activate'
|
454 |
+
executable = root + '/bin/python' + python_version
|
455 |
+
|
456 |
+
# if os.path.isfile(executable) and (not execute('Getting python configuration info...', '{executable}-config --prefix --includes'.format(**vars()), terminate_on_failure=False) ):
|
457 |
+
# print('found executable!')
|
458 |
+
# _, executable_version = execute('Checking Python interpreter version...', '{executable} --version'.format(**vars()), return_='tuple')
|
459 |
+
# executable_version = executable_version.split()[-1]
|
460 |
+
# else: executable_version = ''
|
461 |
+
# print('executable_version: {}'.format(executable_version))
|
462 |
+
#if executable_version != url.rpartition('Python-')[2][:-len('.tgz')]:
|
463 |
+
|
464 |
+
if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature:
|
465 |
+
#print('Install for Python-{} is detected, skipping installation procedure...'.format(python_version))
|
466 |
+
pass
|
467 |
+
|
468 |
+
else:
|
469 |
+
print( 'Installing Python-{python_version}, using {url} with extra:{extra}...'.format( **vars() ) )
|
470 |
+
|
471 |
+
if os.path.isdir(root): shutil.rmtree(root)
|
472 |
+
|
473 |
+
build_prefix = os.path.abspath(root + '/../build-python-{}'.format(python_version) )
|
474 |
+
|
475 |
+
if not os.path.isdir(root): os.makedirs(root)
|
476 |
+
if not os.path.isdir(build_prefix): os.makedirs(build_prefix)
|
477 |
+
|
478 |
+
platform_is_mac = True if platform['os'] in ['mac', 'm1'] else False
|
479 |
+
platform_is_linux = not platform_is_mac
|
480 |
+
|
481 |
+
#if False and platform['os'] == 'mac' and platform_module.machine() == 'arm64' and tuple( map(int, python_version.split('.') ) ) >= (3, 9):
|
482 |
+
if ( platform['os'] == 'mac' and python_version == '3.6' ) \
|
483 |
+
or ( platform_is_linux and python_version in ['3.10', '3.11'] ):
|
484 |
+
open_ssl_url = local_open_ssl_install(root, build_prefix, jobs)
|
485 |
+
options += f' --with-openssl={root} --with-openssl-rpath=auto'
|
486 |
+
#signature += 'OpenSSL install: ' + open_ssl_url + '\n'
|
487 |
+
|
488 |
+
archive = build_prefix + '/' + url.split('/')[-1]
|
489 |
+
build_dir = archive.rpartition('.tgz')[0]
|
490 |
+
if os.path.isdir(build_dir): shutil.rmtree(build_dir)
|
491 |
+
|
492 |
+
with open(archive, 'wb') as f:
|
493 |
+
#response = urllib2.urlopen(url)
|
494 |
+
response = urllib.request.urlopen(url)
|
495 |
+
f.write( response.read() )
|
496 |
+
|
497 |
+
#execute('Execution environment:', 'env'.format(**vars()) )
|
498 |
+
|
499 |
+
execute('Unpacking {}'.format(archive), 'cd {build_prefix} && tar -xvzf {archive}'.format(**vars()) )
|
500 |
+
|
501 |
+
#execute('Building and installing...', 'cd {} && CC={compiler} CXX={cpp_compiler} {extra[0]} ./configure {extra[1]} --prefix={root} && {extra[0]} make -j{jobs} && {extra[0]} make install'.format(build_dir, **locals()) )
|
502 |
+
execute('Configuring...', 'cd {} && CC={compiler} CXX={cpp_compiler} {extra[0]} ./configure {options} {extra[1]} --prefix={root}'.format(build_dir, **locals()) )
|
503 |
+
execute('Building...', 'cd {} && {extra[0]} make -j{jobs}'.format(build_dir, **locals()) )
|
504 |
+
execute('Installing...', 'cd {} && {extra[0]} make -j{jobs} install'.format(build_dir, **locals()) )
|
505 |
+
|
506 |
+
shutil.rmtree(build_prefix)
|
507 |
+
|
508 |
+
#execute('Updating setuptools...', f'cd {root} && {root}/bin/pip{python_version} install --upgrade setuptools wheel' )
|
509 |
+
|
510 |
+
# if 'certifi' not in packages:
|
511 |
+
# packages += ' certifi'
|
512 |
+
|
513 |
+
if packages: execute( f'Installing packages {packages}...', f'cd {root} && unset __PYVENV_LAUNCHER__ && {root}/bin/pip{python_version} install --upgrade {packages}' )
|
514 |
+
#if packages: execute( f'Installing packages {packages}...', f'cd {root} && unset __PYVENV_LAUNCHER__ && {executable} -m pip install --upgrade {packages}' )
|
515 |
+
|
516 |
+
remove_pip_and_easy_install(root) # removing all pip's and easy_install's to make sure that environment is immutable
|
517 |
+
|
518 |
+
with open(signature_file_name, 'w') as f: f.write(signature)
|
519 |
+
|
520 |
+
print( 'Installing Python-{python_version}, using {url} with extra:{extra}... Done.'.format( **vars() ) )
|
521 |
+
|
522 |
+
il = get_python_include_and_lib(executable)
|
523 |
+
|
524 |
+
return NT(
|
525 |
+
python = executable,
|
526 |
+
root = root,
|
527 |
+
python_include_dir = il.python_include_dir,
|
528 |
+
python_lib_dir = il.python_lib_dir,
|
529 |
+
version = python_version,
|
530 |
+
url = url,
|
531 |
+
platform = platform,
|
532 |
+
config = config,
|
533 |
+
hash = hash,
|
534 |
+
)
|
535 |
+
|
536 |
+
|
537 |
+
|
538 |
+
def setup_python_virtual_environment(working_dir, python_environment, packages=''):
|
539 |
+
''' Deploy Python virtual environment at working_dir
|
540 |
+
'''
|
541 |
+
|
542 |
+
python = python_environment.python
|
543 |
+
|
544 |
+
execute('Setting up Python virtual environment...', 'unset __PYVENV_LAUNCHER__ && {python} -m venv --clear {working_dir}'.format(**vars()) )
|
545 |
+
|
546 |
+
activate = f'unset __PYVENV_LAUNCHER__ && . {working_dir}/bin/activate'
|
547 |
+
|
548 |
+
bin=working_dir+'/bin'
|
549 |
+
|
550 |
+
if packages: execute('Installing packages: {}...'.format(packages), 'unset __PYVENV_LAUNCHER__ && {bin}/python {bin}/pip install --upgrade pip setuptools && {bin}/python {bin}/pip install --progress-bar off {packages}'.format(**vars()) )
|
551 |
+
#if packages: execute('Installing packages: {}...'.format(packages), '{bin}/pip{python_environment.version} install {packages}'.format(**vars()) )
|
552 |
+
|
553 |
+
return NT(activate = activate, python = bin + '/python', root = working_dir, bin = bin)
|
554 |
+
|
555 |
+
|
556 |
+
|
557 |
+
def setup_persistent_python_virtual_environment(python_environment, packages):
|
558 |
+
''' Setup _persistent_ and _immutable_ Python virtual environment which will be saved between test runs
|
559 |
+
'''
|
560 |
+
|
561 |
+
if python_environment.version.startswith('2.'):
|
562 |
+
assert not packages, f'ERROR: setup_persistent_python_virtual_environment does not support Python-2.* with non-empty package list!'
|
563 |
+
return NT(activate = ':', python = python_environment.python, root = python_environment.root, bin = python_environment.root + '/bin')
|
564 |
+
|
565 |
+
else:
|
566 |
+
#if 'certifi' not in packages: packages += ' certifi'
|
567 |
+
|
568 |
+
h = hashlib.md5()
|
569 |
+
h.update(f'v1.0.0 platform: {python_environment.platform} python_source_url: {python_environment.url} python-hash: {python_environment.hash} packages: {packages}'.encode('utf-8', errors='backslashreplace') )
|
570 |
+
hash = h.hexdigest()
|
571 |
+
|
572 |
+
prefix = calculate_unique_prefix_path(python_environment.platform, python_environment.config)
|
573 |
+
|
574 |
+
root = os.path.abspath( prefix + '/python_virtual_environments/' + '/python-' + python_environment.version + '/' + hash )
|
575 |
+
signature_file_name = root + '/.signature'
|
576 |
+
signature = f'setup_persistent_python_virtual_environment v1.0.0\npython: {python_environment.hash}\npackages: {packages}\n'
|
577 |
+
|
578 |
+
activate = f'unset __PYVENV_LAUNCHER__ && . {root}/bin/activate'
|
579 |
+
bin = f'{root}/bin'
|
580 |
+
|
581 |
+
if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature: pass
|
582 |
+
else:
|
583 |
+
if os.path.isdir(root): shutil.rmtree(root)
|
584 |
+
setup_python_virtual_environment(root, python_environment, packages=packages)
|
585 |
+
remove_pip_and_easy_install(root) # removing all pip's and easy_install's to make sure that environment is immutable
|
586 |
+
with open(signature_file_name, 'w') as f: f.write(signature)
|
587 |
+
|
588 |
+
return NT(activate = activate, python = bin + '/python', root = root, bin = bin, hash = hash)
|
589 |
+
|
590 |
+
|
591 |
+
|
592 |
+
def _get_path_to_conda_root(platform, config):
|
593 |
+
''' Perform local (prefix) install of miniconda and return NT(activate, conda_root_dir, conda)
|
594 |
+
this function is for inner use only, - to setup custom conda environment inside your test use `setup_conda_virtual_environment` defined below
|
595 |
+
'''
|
596 |
+
miniconda_sources = {
|
597 |
+
'mac' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh',
|
598 |
+
'linux' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh',
|
599 |
+
'aarch64': 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh',
|
600 |
+
'ubuntu' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh',
|
601 |
+
'm1' : 'https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.1-MacOSX-arm64.sh',
|
602 |
+
}
|
603 |
+
|
604 |
+
conda_sources = {
|
605 |
+
'mac' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-MacOSX-x86_64.sh',
|
606 |
+
'linux' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-Linux-x86_64.sh',
|
607 |
+
'ubuntu' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-Linux-x86_64.sh',
|
608 |
+
}
|
609 |
+
|
610 |
+
#platform_os = 'm1' if platform_module.machine() == 'arm64' else platform['os']
|
611 |
+
#url = miniconda_sources[ platform_os ]
|
612 |
+
|
613 |
+
platform_os = platform['os']
|
614 |
+
for o in 'alpine centos ubuntu'.split():
|
615 |
+
if platform_os.startswith(o): platform_os = 'linux'
|
616 |
+
|
617 |
+
url = miniconda_sources[platform_os]
|
618 |
+
|
619 |
+
version = '1'
|
620 |
+
channels = '' # conda-forge
|
621 |
+
|
622 |
+
#packages = ['conda-build gcc libgcc', 'libgcc=5.2.0'] # libgcc installs is workaround for "Anaconda libstdc++.so.6: version `GLIBCXX_3.4.20' not found", see: https://stackoverflow.com/questions/48453497/anaconda-libstdc-so-6-version-glibcxx-3-4-20-not-found
|
623 |
+
#packages = ['conda-build gcc'] # libgcc installs is workaround for "Anaconda libstdc++.so.6: version `GLIBCXX_3.4.20' not found", see: https://stackoverflow.com/questions/48453497/anaconda-libstdc-so-6-version-glibcxx-3-4-20-not-found
|
624 |
+
packages = ['conda-build anaconda-client conda-verify',]
|
625 |
+
|
626 |
+
signature = f'url: {url}\nversion: {version}\channels: {channels}\npackages: {packages}\n'
|
627 |
+
|
628 |
+
root = calculate_unique_prefix_path(platform, config) + '/conda'
|
629 |
+
|
630 |
+
signature_file_name = root + '/.signature'
|
631 |
+
|
632 |
+
# presense of __PYVENV_LAUNCHER__,PYTHONHOME, PYTHONPATH sometimes confuse Python so we have to unset them
|
633 |
+
unset = 'unset __PYVENV_LAUNCHER__ && unset PYTHONHOME && unset PYTHONPATH'
|
634 |
+
activate = unset + ' && . ' + root + '/bin/activate'
|
635 |
+
|
636 |
+
executable = root + '/bin/conda'
|
637 |
+
|
638 |
+
|
639 |
+
if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature:
|
640 |
+
print( f'Install for MiniConda is detected, skipping installation procedure...' )
|
641 |
+
|
642 |
+
else:
|
643 |
+
print( f'Installing MiniConda, using {url}...' )
|
644 |
+
|
645 |
+
if os.path.isdir(root): shutil.rmtree(root)
|
646 |
+
|
647 |
+
build_prefix = os.path.abspath(root + f'/../build-conda' )
|
648 |
+
|
649 |
+
#if not os.path.isdir(root): os.makedirs(root)
|
650 |
+
if not os.path.isdir(build_prefix): os.makedirs(build_prefix)
|
651 |
+
|
652 |
+
archive = build_prefix + '/' + url.split('/')[-1]
|
653 |
+
|
654 |
+
with open(archive, 'wb') as f:
|
655 |
+
response = urllib.request.urlopen(url)
|
656 |
+
f.write( response.read() )
|
657 |
+
|
658 |
+
execute('Installing conda...', f'cd {build_prefix} && {unset} && bash {archive} -b -p {root}' )
|
659 |
+
|
660 |
+
# conda update --yes --quiet -n base -c defaults conda
|
661 |
+
|
662 |
+
if channels: execute(f'Adding extra channles {channels}...', f'cd {build_prefix} && {activate} && conda config --add channels {channels}' )
|
663 |
+
|
664 |
+
for p in packages: execute(f'Installing conda packages: {p}...', f'cd {build_prefix} && {activate} && conda install --quiet --yes {p}' )
|
665 |
+
|
666 |
+
shutil.rmtree(build_prefix)
|
667 |
+
|
668 |
+
with open(signature_file_name, 'w') as f: f.write(signature)
|
669 |
+
|
670 |
+
print( f'Installing MiniConda, using {url}... Done.' )
|
671 |
+
|
672 |
+
execute(f'Updating conda base...', f'{activate} && conda update --all --yes' )
|
673 |
+
return NT(conda=executable, root=root, activate=activate, url=url)
|
674 |
+
|
675 |
+
|
676 |
+
|
677 |
+
def setup_conda_virtual_environment(working_dir, platform, config, packages=''):
|
678 |
+
''' Deploy Conda virtual environment at working_dir
|
679 |
+
'''
|
680 |
+
conda_root_env = _get_path_to_conda_root(platform, config)
|
681 |
+
activate = conda_root_env.activate
|
682 |
+
|
683 |
+
python_version = platform.get('python', DEFAULT_PYTHON_VERSION)
|
684 |
+
|
685 |
+
prefix = os.path.abspath( working_dir + '/.conda-python-' + python_version )
|
686 |
+
|
687 |
+
command_line = f'conda create --quiet --yes --prefix {prefix} python={python_version}'
|
688 |
+
|
689 |
+
execute( f'Setting up Conda for Python-{python_version} virtual environment...', f'cd {working_dir} && {activate} && ( {command_line} || ( conda clean --yes && {command_line} ) )' )
|
690 |
+
|
691 |
+
activate = f'{activate} && conda activate {prefix}'
|
692 |
+
|
693 |
+
if packages: execute( f'Setting up extra packages {packages}...', f'cd {working_dir} && {activate} && conda install --quiet --yes {packages}' )
|
694 |
+
|
695 |
+
python = prefix + '/bin/python' + python_version
|
696 |
+
|
697 |
+
il = get_python_include_and_lib(python)
|
698 |
+
|
699 |
+
return NT(
|
700 |
+
activate = activate,
|
701 |
+
root = prefix,
|
702 |
+
python = python,
|
703 |
+
python_include_dir = il.python_include_dir,
|
704 |
+
python_lib_dir = il.python_lib_dir,
|
705 |
+
version = python_version,
|
706 |
+
activate_base = conda_root_env.activate,
|
707 |
+
url = prefix, # conda_root_env.url,
|
708 |
+
platform=platform,
|
709 |
+
config=config,
|
710 |
+
)
|
711 |
+
|
712 |
+
|
713 |
+
|
714 |
+
class FileLock():
|
715 |
+
''' Implementation of file-lock object that could be use with Python `with` statement
|
716 |
+
'''
|
717 |
+
|
718 |
+
def __init__(self, file_name):
|
719 |
+
self.locked = False
|
720 |
+
self.file_name = file_name
|
721 |
+
|
722 |
+
|
723 |
+
def __enter__(self):
|
724 |
+
if not self.locked: self.acquire()
|
725 |
+
return self
|
726 |
+
|
727 |
+
|
728 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
729 |
+
if self.locked: self.release()
|
730 |
+
|
731 |
+
|
732 |
+
def __del__(self):
|
733 |
+
self.release()
|
734 |
+
|
735 |
+
|
736 |
+
def acquire(self):
|
737 |
+
while True:
|
738 |
+
try:
|
739 |
+
os.close( os.open(self.file_name, os.O_CREAT | os.O_EXCL, mode=0o600) )
|
740 |
+
self.locked = True
|
741 |
+
break
|
742 |
+
|
743 |
+
except FileExistsError as e:
|
744 |
+
time.sleep(60)
|
745 |
+
|
746 |
+
|
747 |
+
def release(self):
|
748 |
+
if self.locked:
|
749 |
+
os.remove(self.file_name)
|
750 |
+
self.locked = False
|
751 |
+
|
752 |
+
|
753 |
+
|
754 |
+
def convert_submodule_urls_from_ssh_to_https(repository_root):
|
755 |
+
''' switching submodules URL to HTTPS so we can clone without SSH key
|
756 |
+
'''
|
757 |
+
with open(f'{repository_root}/.gitmodules') as f: m = f.read()
|
758 |
+
with open(f'{repository_root}/.gitmodules', 'w') as f:
|
759 |
+
f.write(
|
760 |
+
m
|
761 |
+
.replace('url = git@github.com:', 'url = https://github.com/')
|
762 |
+
.replace('url = ../../../', 'url = https://github.com/RosettaCommons/')
|
763 |
+
.replace('url = ../../', 'url = https://github.com/RosettaCommons/')
|
764 |
+
.replace('url = ../', 'url = https://github.com/RosettaCommons/')
|
765 |
+
)
|
.rosetta-ci/tests/rfd.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# :noTabs=true:
|
4 |
+
|
5 |
+
# (c) Copyright Rosetta Commons Member Institutions.
|
6 |
+
# (c) This file is part of the Rosetta software suite and is made available under license.
|
7 |
+
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
|
8 |
+
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
|
9 |
+
# (c) addressed to University of Washington CoMotion, email: license@uw.edu.
|
10 |
+
|
11 |
+
## @file rfd.py
|
12 |
+
## @brief main test files for RFdiffusion
|
13 |
+
## @author Sergey Lyskov
|
14 |
+
|
15 |
+
|
16 |
+
import imp
|
17 |
+
imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/__init__.py') # A bit of Python magic here, what we trying to say is this: from __init__ import *, but init is calculated from file location
|
18 |
+
|
19 |
+
_api_version_ = '1.0'
|
20 |
+
|
21 |
+
import os, tempfile, shutil
|
22 |
+
import urllib.request
|
23 |
+
|
24 |
+
|
25 |
+
_models_urls_ = '''
|
26 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt
|
27 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt
|
28 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt
|
29 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/74f51cfb8b440f50d70878e05361d8f0/InpaintSeq_ckpt.pt
|
30 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/76d00716416567174cdb7ca96e208296/InpaintSeq_Fold_ckpt.pt
|
31 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/5532d2e1f3a4738decd58b19d633b3c3/ActiveSite_ckpt.pt
|
32 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/12fc204edeae5b57713c5ad7dcb97d39/Base_epoch8_ckpt.pt
|
33 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt
|
34 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/1befcb9b28e2f778f53d47f18b7597fa/RF_structure_prediction_weights.pt
|
35 |
+
'''.split()
|
36 |
+
|
37 |
+
|
38 |
+
def run_main_test_suite(repository_root, working_dir, platform, config, debug):
|
39 |
+
full_log = ''
|
40 |
+
|
41 |
+
python_environment = local_python_install(platform, config)
|
42 |
+
|
43 |
+
models_dir = repository_root + '/models'
|
44 |
+
if not os.path.isdir(models_dir): os.makedirs(models_dir)
|
45 |
+
|
46 |
+
for url in _models_urls_:
|
47 |
+
file_name = models_dir + '/' + url.split('/')[-1]
|
48 |
+
tmp_file_name = file_name + '.tmp'
|
49 |
+
if not os.path.isfile(file_name):
|
50 |
+
print(f'downloading {url}...')
|
51 |
+
full_log += f'downloading {url}...\n'
|
52 |
+
urllib.request.urlretrieve(url, tmp_file_name)
|
53 |
+
os.rename(tmp_file_name, file_name)
|
54 |
+
|
55 |
+
execute('unpacking ppi scaffolds...', f'cd {repository_root} && tar -xvf examples/ppi_scaffolds_subset.tar.gz -C examples')
|
56 |
+
|
57 |
+
with tempfile.TemporaryDirectory(dir=working_dir) as tmpdirname:
|
58 |
+
# tmpdirname = working_dir+'/.ve'
|
59 |
+
# if True:
|
60 |
+
|
61 |
+
#ve = setup_persistent_python_virtual_environment(python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl')
|
62 |
+
#ve = setup_python_virtual_environment(working_dir+'/.ve', python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl e3nn icecream pyrsistent wandb pynvml decorator jedi hydra-core')
|
63 |
+
ve = setup_python_virtual_environment(tmpdirname, python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl e3nn icecream pyrsistent wandb pynvml decorator jedi hydra-core')
|
64 |
+
|
65 |
+
execute('Installing local se3-transformer package...', f'cd {repository_root}/env/SE3Transformer && {ve.bin}/pip3 install --editable .')
|
66 |
+
execute('Installing RFdiffusion package...', f'cd {repository_root} && {ve.bin}/pip3 install --editable .')
|
67 |
+
|
68 |
+
#res, output = execute('running unit tests...', f'{ve.activate} && cd {repository_root} && python -m unittest', return_='tuple', add_message_and_command_line_to_output=True)
|
69 |
+
#res, output = execute('running unit tests...', f'cd {repository_root} && {ve.bin}/pytest', return_='tuple')
|
70 |
+
|
71 |
+
|
72 |
+
results_file = f'{repository_root}/tests/.results.json'
|
73 |
+
if os.path.isfile(results_file): os.remove(results_file)
|
74 |
+
|
75 |
+
res, output = execute('running RFdiffusion tests...', f'{ve.activate} && cd {repository_root}/tests && python test_diffusion.py', return_='tuple', add_message_and_command_line_to_output=True)
|
76 |
+
|
77 |
+
if os.path.isfile(results_file):
|
78 |
+
with open(results_file) as f: sub_tests_reults = json.load(f)
|
79 |
+
|
80 |
+
state = _S_passed_
|
81 |
+
for r in sub_tests_reults.values():
|
82 |
+
if r[_StateKey_] == _S_failed_:
|
83 |
+
state = _S_failed_
|
84 |
+
break
|
85 |
+
|
86 |
+
else:
|
87 |
+
sub_tests_reults = {}
|
88 |
+
output += '\n\nEmpty sub-test results, marking test as `failed`...'
|
89 |
+
state = _S_failed_
|
90 |
+
|
91 |
+
shutil.move(f'{repository_root}/tests/outputs', f'{working_dir}/outputs')
|
92 |
+
|
93 |
+
for d in os.listdir(f'{repository_root}/tests'):
|
94 |
+
p = f'{repository_root}/tests/{d}'
|
95 |
+
if d.startswith('tests_') and os.path.isdir(p): shutil.rmtree(p)
|
96 |
+
|
97 |
+
results = {
|
98 |
+
_StateKey_ : state,
|
99 |
+
_LogKey_ : full_log + '\n' + output,
|
100 |
+
_ResultsKey_ : {
|
101 |
+
_TestsKey_ : sub_tests_reults,
|
102 |
+
},
|
103 |
+
}
|
104 |
+
|
105 |
+
return results
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def run(test, repository_root, working_dir, platform, config, hpc_driver=None, verbose=False, debug=False):
|
110 |
+
if test == '': return run_main_test_suite(repository_root=repository_root, working_dir=working_dir, platform=platform, config=config, debug=debug)
|
111 |
+
else: raise BenchmarkError('Unknow scripts test: {}!'.format(test))
|
.rosetta-ci/tests/self.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# self test suite
|
2 |
+
These tests are design to help debug interface between testing server and Rosetta testing scripts
|
3 |
+
|
4 |
+
-----
|
5 |
+
### python
|
6 |
+
Test Python platform support and functionality of local and persistent Python virtual environments
|
.rosetta-ci/tests/self.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# :noTabs=true:
|
4 |
+
|
5 |
+
# (c) Copyright Rosetta Commons Member Institutions.
|
6 |
+
# (c) This file is part of the Rosetta software suite and is made available under license.
|
7 |
+
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
|
8 |
+
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
|
9 |
+
# (c) addressed to University of Washington CoMotion, email: license@uw.edu.
|
10 |
+
|
11 |
+
## @file dummy.py
|
12 |
+
## @brief self-test and debug-aids tests
|
13 |
+
## @author Sergey Lyskov
|
14 |
+
|
15 |
+
import os, os.path, shutil, re, string
|
16 |
+
import json
|
17 |
+
|
18 |
+
import random
|
19 |
+
|
20 |
+
import imp
|
21 |
+
imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/__init__.py') # A bit of Python magic here, what we trying to say is this: from __init__ import *, but init is calculated from file location
|
22 |
+
|
23 |
+
_api_version_ = '1.0'
|
24 |
+
|
25 |
+
|
26 |
+
def run_state_test(repository_root, working_dir, platform, config):
|
27 |
+
revision_id = config['revision']
|
28 |
+
states = (_S_passed_, _S_failed_, _S_build_failed_, _S_script_failed_)
|
29 |
+
state = states[revision_id % len(states)]
|
30 |
+
|
31 |
+
return {_StateKey_ : state, _ResultsKey_ : {}, _LogKey_ : f'run_state_test: setting test state to {state!r}...' }
|
32 |
+
|
33 |
+
|
34 |
+
sub_test_description_template = '''\
|
35 |
+
# subtests_test test suite
|
36 |
+
These sub-test description is generated for 3/4 of sub-tests
|
37 |
+
|
38 |
+
-----
|
39 |
+
### {name}
|
40 |
+
The warm time, had already disappeared like dust. Broken rain, fragment of light shadow, bring more pain to my heart...
|
41 |
+
-----
|
42 |
+
'''
|
43 |
+
|
44 |
+
def run_subtests_test(repository_root, working_dir, platform, config):
|
45 |
+
tests = {}
|
46 |
+
for i in range(16):
|
47 |
+
name = f's-{i:02}'
|
48 |
+
log = ('x'*63 + '\n') * 16 * 256 * i
|
49 |
+
s = i % 3
|
50 |
+
if s == 0: state = _S_passed_
|
51 |
+
elif s == 1: state = _S_failed_
|
52 |
+
else: state = _S_script_failed_
|
53 |
+
|
54 |
+
if i % 4:
|
55 |
+
os.mkdir( f'{working_dir}/{name}' )
|
56 |
+
with open(f'{working_dir}/{name}/description.md', 'w') as f: f.write( sub_test_description_template.format(**vars()) )
|
57 |
+
|
58 |
+
with open( f'{working_dir}/{name}/fantome.txt', 'w') as f: f.write('No one wants to hear the sequel to a fairytale\n')
|
59 |
+
|
60 |
+
tests[name] = { _StateKey_ : state, _LogKey_ : log, }
|
61 |
+
|
62 |
+
test_log = ('*'*63 + '\n') * 16 * 1024 * 16
|
63 |
+
return {_StateKey_ : _S_failed_, _ResultsKey_ : {_TestsKey_: tests}, _LogKey_ : test_log }
|
64 |
+
|
65 |
+
|
66 |
+
def run_regression_test(repository_root, working_dir, platform, config):
|
67 |
+
const = 'const'
|
68 |
+
volatile = 'volatile'
|
69 |
+
new = ''.join( random.sample( string.ascii_letters + string.digits, 8) )
|
70 |
+
oversized = 'oversized'
|
71 |
+
|
72 |
+
sub_tests = [const, volatile, new]
|
73 |
+
|
74 |
+
const_dir = working_dir + '/' + const
|
75 |
+
os.mkdir(const_dir)
|
76 |
+
with open(const_dir + '/const_data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(32) ) ) )
|
77 |
+
|
78 |
+
volatile_dir = working_dir + '/' + volatile
|
79 |
+
os.mkdir(volatile_dir)
|
80 |
+
with open(volatile_dir + '/const_data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(32, 64) ) ) )
|
81 |
+
with open(volatile_dir + '/volatile_data', 'w') as f: f.write( '\n'.join( ( ''.join(random.sample( string.ascii_letters + string.digits, 8) ) for i in range(32) ) ) )
|
82 |
+
|
83 |
+
new_dir = working_dir + '/' + new
|
84 |
+
os.mkdir(new_dir)
|
85 |
+
with open(new_dir + '/data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(64)) ) )
|
86 |
+
|
87 |
+
|
88 |
+
new_dir = working_dir + '/' + oversized
|
89 |
+
os.mkdir(new_dir)
|
90 |
+
with open(new_dir + '/large', 'w') as f: f.write( ('x'*63 + '\n')*16*1024*256 +'extra')
|
91 |
+
|
92 |
+
return {_StateKey_ : _S_queued_for_comparison_, _ResultsKey_ : {}, _LogKey_ : f'sub-tests: {sub_tests!r}' }
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
def run_release_test(repository_root, working_dir, platform, config):
|
97 |
+
release_root = config['mounts'].get('release_root')
|
98 |
+
|
99 |
+
branch = config['branch']
|
100 |
+
revision = config['revision']
|
101 |
+
|
102 |
+
assert release_root, "config['release_root'] must be set!"
|
103 |
+
|
104 |
+
release_path = f'{release_root}/dummy'
|
105 |
+
|
106 |
+
if not os.path.isdir(release_path): os.makedirs(release_path)
|
107 |
+
|
108 |
+
with open(f'{release_path}/{branch}-{revision}.txt', 'w') as f: f.write('dummy release file\n')
|
109 |
+
|
110 |
+
return {_StateKey_ : _S_passed_, _ResultsKey_ : {}, _LogKey_ : f'Config release root set to: {release_root}'}
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
def run_python_test(repository_root, working_dir, platform, config):
|
115 |
+
|
116 |
+
import zlib, ssl
|
117 |
+
|
118 |
+
python_environment = local_python_install(platform, config)
|
119 |
+
|
120 |
+
if platform['python'][0] == '2': pass
|
121 |
+
else:
|
122 |
+
|
123 |
+
if platform['os'] == 'mac' and int( platform['python'].split('.')[1] ) > 6 :
|
124 |
+
# SSL certificate test
|
125 |
+
import urllib.request; urllib.request.urlopen('https://benchmark.graylab.jhu.edu')
|
126 |
+
|
127 |
+
ves = [
|
128 |
+
setup_persistent_python_virtual_environment(python_environment, packages='colr dice xdice pdp11games'),
|
129 |
+
setup_python_virtual_environment(working_dir, python_environment, packages='colr dice xdice pdp11games'),
|
130 |
+
]
|
131 |
+
|
132 |
+
for ve in ves:
|
133 |
+
commands = [
|
134 |
+
'import colr, dice, xdice, pdp11games',
|
135 |
+
]
|
136 |
+
|
137 |
+
if platform['os'] == 'mac' and int( platform['python'].split('.')[1] ) > 6 :
|
138 |
+
# SSL certificate test
|
139 |
+
commands.append('import urllib.request; urllib.request.urlopen("https://benchmark.graylab.jhu.edu/queue")')
|
140 |
+
|
141 |
+
for command in commands:
|
142 |
+
execute('Testing local Python virtual enviroment...', f"{ve.activate} && {ve.python} -c '{command}'")
|
143 |
+
execute('Testing local Python virtual enviroment...', f"{ve.activate} && python -c '{command}'")
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
return {_StateKey_ : _S_passed_, _ResultsKey_ : {}, _LogKey_ : f'Done!'}
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
def compare(test, results, files_path, previous_results, previous_files_path):
|
152 |
+
"""
|
153 |
+
Compare the results of two tests run (new vs. previous) for regression test
|
154 |
+
Take two dict and two paths
|
155 |
+
Must return standard dict with results
|
156 |
+
|
157 |
+
:param test: str
|
158 |
+
:param results: dict
|
159 |
+
:param files_path: str
|
160 |
+
:param previous_results: dict
|
161 |
+
:param previous_files_path: str
|
162 |
+
:rtype: dict
|
163 |
+
"""
|
164 |
+
ignore_files = []
|
165 |
+
|
166 |
+
results = dict(tests={}, summary=dict(total=0, failed=0, failed_tests=[])) # , config={}
|
167 |
+
|
168 |
+
if previous_files_path:
|
169 |
+
for test in os.listdir(files_path):
|
170 |
+
if os.path.isdir(files_path + '/' + test):
|
171 |
+
exclude = ''.join([' --exclude="{}"'.format(f) for f in ignore_files] ) + ' --exclude="*.ignore"'
|
172 |
+
res, brief_diff = execute('Comparing {}...'.format(test), 'diff -rq {exclude} {0}/{test} {1}/{test}'.format(previous_files_path, files_path, test=test, exclude=exclude), return_='tuple')
|
173 |
+
res, full_diff = execute('Comparing {}...'.format(test), 'diff -r {exclude} {0}/{test} {1}/{test}'.format(previous_files_path, files_path, test=test, exclude=exclude), return_='tuple')
|
174 |
+
diff = 'Brief Diff:\n' + brief_diff + ( ('\n\nFull Diff:\n' + full_diff[:1024*1024*1]) if full_diff != brief_diff else '' )
|
175 |
+
|
176 |
+
state = _S_failed_ if res else _S_passed_
|
177 |
+
results['tests'][test] = {_StateKey_: state, _LogKey_: diff if state != _S_passed_ else ''}
|
178 |
+
|
179 |
+
results['summary']['total'] += 1
|
180 |
+
if res: results['summary']['failed'] += 1; results['summary']['failed_tests'].append(test)
|
181 |
+
|
182 |
+
else: # no previous tests case, returning 'passed' for all sub_tests
|
183 |
+
for test in os.listdir(files_path):
|
184 |
+
if os.path.isdir(files_path + '/' + test):
|
185 |
+
results['tests'][test] = {_StateKey_: _S_passed_, _LogKey_: 'First run, no previous results available. Skipping comparison...\n'}
|
186 |
+
results['summary']['total'] += 1
|
187 |
+
|
188 |
+
for test in os.listdir(files_path):
|
189 |
+
if os.path.isdir(files_path + '/' + test):
|
190 |
+
if os.path.isfile(files_path+'/'+test+'/.test_did_not_run.log') or os.path.isfile(files_path+'/'+test+'/.test_got_timeout_kill.log'):
|
191 |
+
results['tests'][test][_StateKey_] = _S_script_failed_
|
192 |
+
results['tests'][test][_LogKey_] += '\nCompare(...): Marking as "Script failed" due to presense of .test_did_not_run.log or .test_got_timeout_kill.log file!\n'
|
193 |
+
if test not in results['summary']['failed_tests']:
|
194 |
+
results['summary']['failed'] += 1
|
195 |
+
results['summary']['failed_tests'].append(test)
|
196 |
+
|
197 |
+
state = _S_failed_ if results['summary']['failed'] else _S_passed_
|
198 |
+
|
199 |
+
return {_StateKey_: state, _LogKey_: 'Comparison dummy log...', _ResultsKey_: results}
|
200 |
+
|
201 |
+
|
202 |
+
def run(test, repository_root, working_dir, platform, config, hpc_driver=None, verbose=False, debug=False):
|
203 |
+
if test == 'state': return run_state_test (repository_root, working_dir, platform, config)
|
204 |
+
elif test == 'regression': return run_regression_test (repository_root, working_dir, platform, config)
|
205 |
+
elif test == 'subtests': return run_subtests_test (repository_root, working_dir, platform, config)
|
206 |
+
elif test == 'release': return run_release_test (repository_root, working_dir, platform, config)
|
207 |
+
elif test == 'python': return run_python_test (repository_root, working_dir, platform, config)
|
208 |
+
|
209 |
+
else: raise BenchmarkError(f'Dummy test script does not support run with test={test!r}!')
|
END
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"retCode":"100",
|
3 |
+
"retData":null,
|
4 |
+
"retMsg":"操作成功",
|
5 |
+
"retTime":"2022-11-05 22:20:09",
|
6 |
+
"success":true
|
7 |
+
}
|
LICENSE
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD License
|
2 |
+
|
3 |
+
Copyright (c) 2023 University of Washington. Developed at the Institute for
|
4 |
+
Protein Design by Joseph Watson, David Juergens, Nathaniel Bennett, Brian Trippe
|
5 |
+
and Jason Yim
|
6 |
+
|
7 |
+
Redistribution and use in source and binary forms, with or without
|
8 |
+
modification, are permitted provided that the following conditions are met:
|
9 |
+
|
10 |
+
Redistributions of source code must retain the above copyright notice, this
|
11 |
+
list of conditions and the following disclaimer.
|
12 |
+
|
13 |
+
Redistributions in binary form must reproduce the above copyright notice, this
|
14 |
+
list of conditions and the following disclaimer in the documentation and/or
|
15 |
+
other materials provided with the distribution.
|
16 |
+
|
17 |
+
Neither the name of the University of Washington nor the names of its
|
18 |
+
contributors may be used to endorse or promote products derived from this
|
19 |
+
software without specific prior written permission.
|
20 |
+
|
21 |
+
THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON AND CONTRIBUTORS “AS
|
22 |
+
IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24 |
+
DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF WASHINGTON OR CONTRIBUTORS BE
|
25 |
+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
26 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
27 |
+
GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
|
28 |
+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
29 |
+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
|
30 |
+
OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
CHANGED
@@ -1,3 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# RF*diffusion*
|
2 |
+
|
3 |
+
<!--
|
4 |
+
<img width="1115" alt="Screen Shot 2023-01-19 at 5 56 33 PM" src="https://user-images.githubusercontent.com/56419265/213588200-f8f44dba-276e-4dd2-b844-15acc441458d.png">
|
5 |
+
-->
|
6 |
+
<p align="center">
|
7 |
+
<img src="./img/diffusion_protein_gradient_2.jpg" alt="alt text" width="1100px" align="middle"/>
|
8 |
+
</p>
|
9 |
+
|
10 |
+
*Image: Ian C. Haydon / UW Institute for Protein Design*
|
11 |
+
|
12 |
+
## Description
|
13 |
+
|
14 |
+
RFdiffusion is an open source method for structure generation, with or without conditional information (a motif, target etc). It can perform a whole range of protein design challenges as we have outlined in [the RFdiffusion paper](https://www.biorxiv.org/content/10.1101/2022.12.09.519842v1).
|
15 |
+
|
16 |
+
**Things Diffusion can do**
|
17 |
+
- Motif Scaffolding
|
18 |
+
- Unconditional protein generation
|
19 |
+
- Symmetric unconditional generation (cyclic, dihedral and tetrahedral symmetries currently implemented, more coming!)
|
20 |
+
- Symmetric motif scaffolding
|
21 |
+
- Binder design
|
22 |
+
- Design diversification ("partial diffusion", sampling around a design)
|
23 |
+
|
24 |
+
----
|
25 |
+
|
26 |
+
# Table of contents
|
27 |
+
|
28 |
+
- [RF*diffusion*](#rfdiffusion)
|
29 |
+
- [Description](#description)
|
30 |
+
- [Table of contents](#table-of-contents)
|
31 |
+
- [Getting started / installation](#getting-started--installation)
|
32 |
+
- [Conda Install SE3-Transformer](#conda-install-se3-transformer)
|
33 |
+
- [Get PPI Scaffold Examples](#get-ppi-scaffold-examples)
|
34 |
+
- [Usage](#usage)
|
35 |
+
- [Running the diffusion script](#running-the-diffusion-script)
|
36 |
+
- [Basic execution - an unconditional monomer](#basic-execution---an-unconditional-monomer)
|
37 |
+
- [Motif Scaffolding](#motif-scaffolding)
|
38 |
+
- [The "active site" model holds very small motifs in place](#the-active-site-model-holds-very-small-motifs-in-place)
|
39 |
+
- [The `inpaint_seq` flag](#the-inpaint_seq-flag)
|
40 |
+
- [A note on `diffuser.T`](#a-note-on-diffusert)
|
41 |
+
- [Partial diffusion](#partial-diffusion)
|
42 |
+
- [Binder Design](#binder-design)
|
43 |
+
- [Practical Considerations for Binder Design](#practical-considerations-for-binder-design)
|
44 |
+
- [Fold Conditioning](#fold-conditioning)
|
45 |
+
- [Generation of Symmetric Oligomers](#generation-of-symmetric-oligomers)
|
46 |
+
- [Using Auxiliary Potentials](#using-auxiliary-potentials)
|
47 |
+
- [Symmetric Motif Scaffolding.](#symmetric-motif-scaffolding)
|
48 |
+
- [A Note on Model Weights](#a-note-on-model-weights)
|
49 |
+
- [Things you might want to play with at inference time](#things-you-might-want-to-play-with-at-inference-time)
|
50 |
+
- [Understanding the output files](#understanding-the-output-files)
|
51 |
+
- [Docker](#docker)
|
52 |
+
- [Conclusion](#conclusion)
|
53 |
+
|
54 |
+
# Getting started / installation
|
55 |
+
|
56 |
+
Thanks to Sergey Ovchinnikov, RFdiffusion is available as a [Google Colab Notebook](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/rf/examples/diffusion.ipynb) if you would like to run it there!
|
57 |
+
|
58 |
+
We strongly recommend reading this README carefully before getting started with RFdiffusion, and working through some of the examples in the Colab Notebook.
|
59 |
+
|
60 |
+
If you want to set up RFdiffusion locally, follow the steps below:
|
61 |
+
|
62 |
+
To get started using RFdiffusion, clone the repo:
|
63 |
+
```
|
64 |
+
git clone https://github.com/RosettaCommons/RFdiffusion.git
|
65 |
+
```
|
66 |
+
|
67 |
+
You'll then need to download the model weights into the RFDiffusion directory.
|
68 |
+
```
|
69 |
+
cd RFdiffusion
|
70 |
+
mkdir models && cd models
|
71 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt
|
72 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt
|
73 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt
|
74 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/74f51cfb8b440f50d70878e05361d8f0/InpaintSeq_ckpt.pt
|
75 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/76d00716416567174cdb7ca96e208296/InpaintSeq_Fold_ckpt.pt
|
76 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/5532d2e1f3a4738decd58b19d633b3c3/ActiveSite_ckpt.pt
|
77 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/12fc204edeae5b57713c5ad7dcb97d39/Base_epoch8_ckpt.pt
|
78 |
+
|
79 |
+
Optional:
|
80 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt
|
81 |
+
|
82 |
+
# original structure prediction weights
|
83 |
+
wget http://files.ipd.uw.edu/pub/RFdiffusion/1befcb9b28e2f778f53d47f18b7597fa/RF_structure_prediction_weights.pt
|
84 |
+
```
|
85 |
+
|
86 |
+
|
87 |
+
### Conda Install SE3-Transformer
|
88 |
+
|
89 |
+
Ensure that you have either [Anaconda or Miniconda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) installed.
|
90 |
+
|
91 |
+
You also need to install [NVIDIA's implementation of SE(3)-Transformers](https://developer.nvidia.com/blog/accelerating-se3-transformers-training-using-an-nvidia-open-source-model-implementation/) Here is how to install the NVIDIA SE(3)-Transformer code:
|
92 |
+
|
93 |
+
```
|
94 |
+
conda env create -f env/SE3nv.yml
|
95 |
+
|
96 |
+
conda activate SE3nv
|
97 |
+
cd env/SE3Transformer
|
98 |
+
pip install --no-cache-dir -r requirements.txt
|
99 |
+
python setup.py install
|
100 |
+
cd ../.. # change into the root directory of the repository
|
101 |
+
pip install -e . # install the rfdiffusion module from the root of the repository
|
102 |
+
```
|
103 |
+
Anytime you run diffusion you should be sure to activate this conda environment by running the following command:
|
104 |
+
```
|
105 |
+
conda activate SE3nv
|
106 |
+
```
|
107 |
+
Total setup should take less than 30 minutes on a standard desktop computer.
|
108 |
+
Note: Due to the variation in GPU types and drivers that users have access to, we are not able to make one environment that will run on all setups. As such, we are only providing a yml file with support for CUDA 11.1 and leaving it to each user to customize it to work on their setups. This customization will involve changing the cudatoolkit and (possibly) the PyTorch version specified in the yml file.
|
109 |
+
|
110 |
---
|
111 |
+
|
112 |
+
### Get PPI Scaffold Examples
|
113 |
+
|
114 |
+
To run the scaffolded protein binder design (PPI) examples, we have provided some example scaffold files (`examples/ppi_scaffolds_subset.tar.gz`).
|
115 |
+
You'll need to untar this:
|
116 |
+
```
|
117 |
+
tar -xvf examples/ppi_scaffolds_subset.tar.gz -C examples/
|
118 |
+
```
|
119 |
+
|
120 |
+
We will explain what these files are and how to use them in the Fold Conditioning section.
|
121 |
+
|
122 |
+
----
|
123 |
+
|
124 |
+
|
125 |
+
# Usage
|
126 |
+
In this section we will demonstrate how to run diffusion.
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
<p align="center">
|
131 |
+
<img src="./img/main.png" alt="alt text" width="1100px" align="middle"/>
|
132 |
+
</p>
|
133 |
+
|
134 |
+
|
135 |
+
### Running the diffusion script
|
136 |
+
The actual script you will execute is called `scripts/run_inference.py`. There are many ways to run it, governed by hydra configs.
|
137 |
+
[Hydra configs](https://hydra.cc/docs/configure_hydra/intro/) are a nice way of being able to specify many different options, with sensible defaults drawn *directly* from the model checkpoint, so inference should always, by default, match training.
|
138 |
+
What this means is that the default values in `config/inference/base.yml` might not match the actual values used during inference, with a specific checkpoint. This is all handled under the hood.
|
139 |
+
|
140 |
---
|
141 |
+
### Basic execution - an unconditional monomer
|
142 |
+
<img src="./img/cropped_uncond.png" alt="alt text" width="400px" align="right"/>
|
143 |
+
|
144 |
+
Let's first look at how you would do unconditional design of a protein of length 150aa.
|
145 |
+
For this, we just need to specify three things:
|
146 |
+
1. The length of the protein
|
147 |
+
2. The location where we want to write files to
|
148 |
+
3. The number of designs we want
|
149 |
+
|
150 |
+
```
|
151 |
+
./scripts/run_inference.py 'contigmap.contigs=[150-150]' inference.output_prefix=test_outputs/test inference.num_designs=10
|
152 |
+
```
|
153 |
+
|
154 |
+
Let's look at this in detail.
|
155 |
+
Firstly, what is `contigmap.contigs`?
|
156 |
+
Hydra configs tell the inference script how it should be run. To keep things organised, the config has different sub-configs, one of them being `contigmap`, which pertains to everything related to the contig string (that defines the protein being built).
|
157 |
+
Take a look at the config file if this isn't clear: `configs/inference/base.yml`
|
158 |
+
Anything in the config can be overwritten manually from the command line. You could, for example, change how the diffuser works:
|
159 |
+
```
|
160 |
+
diffuser.crd_scale=0.5
|
161 |
+
```
|
162 |
+
... but don't do this unless you really know what you're doing!!
|
163 |
+
|
164 |
+
|
165 |
+
Now, what does `'contigmap.contigs=[150-150]'` mean?
|
166 |
+
To those who have used RFjoint inpainting, this might look familiar, but a little bit different. Diffusion, in fact, uses the identical 'contig mapper' as inpainting, except that, because we're using hydra, we have to give this to the model in a different way. The contig string has to be passed as a single-item in a list, rather than as a string, for hydra reasons and the entire argument MUST be enclosed in `''` so that the commandline does not attempt to parse any of the special characters.
|
167 |
+
|
168 |
+
The contig string allows you to specify a length range, but here, we just want a protein of 150aa in length, so you just specify [150-150]
|
169 |
+
This will then run 10 diffusion trajectories, saving the outputs to your specified output folder.
|
170 |
+
|
171 |
+
NB the first time you run RFdiffusion, it will take a while 'Calculating IGSO3'. Once it has done this, it'll be cached for future reference though! For an additional example of unconditional monomer generation, take a look at `./examples/design_unconditional.sh` in the repo!
|
172 |
+
|
173 |
+
---
|
174 |
+
### Motif Scaffolding
|
175 |
+
<!--
|
176 |
+
<p align="center">
|
177 |
+
<img src="./img/motif.png" alt="alt text" width="700px" align="middle"/>
|
178 |
+
</p>
|
179 |
+
-->
|
180 |
+
RFdiffusion can be used to scaffold motifs, in a manner akin to [Constrained Hallucination and RFjoint Inpainting](https://www.science.org/doi/10.1126/science.abn2100#:~:text=The%20binding%20and%20catalytic%20functions%20of%20proteins%20are,the%20fold%20or%20secondary%20structure%20of%20the%20scaffold.). In general, RFdiffusion significantly outperforms both Constrained Hallucination and RFjoint Inpainting.
|
181 |
+
<p align="center">
|
182 |
+
<img src="./img/motif.png" alt="alt text" width="700px" align="middle"/>
|
183 |
+
</p>
|
184 |
+
|
185 |
+
When scaffolding protein motifs, we need a way of specifying that we want to scaffold some particular protein input (one or more segments from a `.pdb` file), and to be able to specify how we want these connected, and by how many residues, in the new scaffolded protein. What's more, we want to be able to sample different lengths of connecting protein, as we generally don't know *a priori* precisely how many residues we'll need to best scaffold a motif. This job of specifying inputs is handled by contigs, governed by the contigmap config in the hydra config. For those familiar with Constrained Hallucination or RFjoint Inpainting, the logic is very similar.
|
186 |
+
Briefly:
|
187 |
+
- Anything prefixed by a letter indicates that this is a motif, with the letter corresponding to the chain letter in the input pdb files. E.g. A10-25 pertains to residues ('A',10),('A',11)...('A',25) in the corresponding input pdb
|
188 |
+
- Anything not prefixed by a letter indicates protein *to be built*. This can be input as a length range. These length ranges are randomly sampled each iteration of RFdiffusion inference.
|
189 |
+
- To specify chain breaks, we use `/0 `.
|
190 |
+
|
191 |
+
In more detail, if we want to scaffold a motif, the input is just like RFjoint Inpainting, except needing to navigate the hydra config input. If we want to scaffold residues 10-25 on chain A a pdb, this would be done with `'contigmap.contigs=[5-15/A10-25/30-40]'`. This asks RFdiffusion to build 5-15 residues (randomly sampled at each inference cycle) N-terminally of A10-25 from the input pdb, followed by 30-40 residues (again, randomly sampled) to its C-terminus. If we wanted to ensure the length was always e.g. 55 residues, this can be specified with `contigmap.length=55-55`. You need to obviously also provide a path to your pdb file: `inference.input_pdb=path/to/file.pdb`. It doesn't matter if your input pdb has residues you *don't* want to scaffold - the contig map defines which residues in the pdb are actually used as the "motif". In other words, even if your pdb files has a B chain, and other residues on the A chain, *only* A10-25 will be provided to RFdiffusion.
|
192 |
+
|
193 |
+
To specify that we want to inpaint in the presence of a separate chain, this can be done as follows:
|
194 |
+
|
195 |
+
```
|
196 |
+
'contigmap.contigs=[5-15/A10-25/30-40/0 B1-100]'
|
197 |
+
```
|
198 |
+
Look at this carefully. `/0 ` is the indicator that we want a chain break. NOTE, the space is important here. This tells the diffusion model to add a big residue jump (200aa) to the input, so that the model sees the first chain as being on a separate chain to the second.
|
199 |
+
|
200 |
+
An example of motif scaffolding can be found in `./examples/design_motifscaffolding.sh`.
|
201 |
+
|
202 |
+
### The "active site" model holds very small motifs in place
|
203 |
+
In the RFdiffusion preprint we noted that for very small motifs, RFdiffusion has the tendency to not keep them perfectly fixed in the output. Therefore, for scaffolding minimalist sites such as enzyme active sites, we fine-tuned RFdiffusion on examples similar to these tasks, allowing it to hold smaller motifs better in place, and better generate *in silico* successes. If your input functional motif is very small, we reccomend using this model, which can easily be specified using the following syntax:
|
204 |
+
`inference.ckpt_override_path=models/ActiveSite_ckpt.pt`
|
205 |
+
|
206 |
+
### The `inpaint_seq` flag
|
207 |
+
For those familiar with RFjoint Inpainting, the contigmap.inpaint_seq input is equivalent. The idea is that often, when, for example, fusing two proteins, residues that were on the surface of a protein (and are therefore likely polar), now need to be packed into the 'core' of the protein. We therefore want them to become hydrophobic residues. What we can do, rather than directly mutating them to hydrophobics, is to mask their sequence identity, and allow RFdiffusion to implicitly reason over their sequence, and better pack against them. This requires a different model than the 'base' diffusion model, that has been trained to understand this paradigm, but this is automatically handled by the inference script (you don't need to do anything).
|
208 |
+
|
209 |
+
To specify amino acids whose sequence should be hidden, use the following syntax:
|
210 |
+
```
|
211 |
+
'contigmap.inpaint_seq=[A1/A30-40]'
|
212 |
+
```
|
213 |
+
Here, we're masking the residue identity of residue A1, and all residues between A30 and A40 (inclusive).
|
214 |
+
|
215 |
+
An example of executing motif scaffolding with the `contigmap.inpaint_seq` flag is located in `./examples/design_motifscaffolding_inpaintseq.sh`
|
216 |
+
|
217 |
+
### A note on `diffuser.T`
|
218 |
+
RFdiffusion was originally trained with 200 discrete timesteps. However, recent improvements have allowed us to reduce the number of timesteps we need to use at inference time. In many cases, running with as few as approximately 20 steps provides outputs of equivalent *in silico* quality to running with 200 steps (providing a 10X speedup). The default is now set to 50 steps. Noting this is important for understanding the partial diffusion, described below.
|
219 |
+
|
220 |
+
---
|
221 |
+
### Partial diffusion
|
222 |
+
|
223 |
+
Something we can do with diffusion is to partially noise and de-noise a structure, to get some diversity around a general fold. This can work really nicely (see [Vazquez-Torres et al., BioRxiv 2022](https://www.biorxiv.org/content/10.1101/2022.12.10.519862v4.abstract)).
|
224 |
+
This is specified by using the diffuser.parial_T input, and setting a timestep to 'noise' to.
|
225 |
+
<p align="center">
|
226 |
+
<img src="./img/partial.png" alt="alt text" width="800px" align="middle"/>
|
227 |
+
</p>
|
228 |
+
More noise == more diversity. In Vazquez-Torres et al., 2022, we typically used `diffuser.partial_T` of approximately 80, but this was with respect to the 200 timesteps we were using. Now that the default `diffuser.T` is 50, you will need to adjust diffuser.partial_T accordingly. E.g. now that `diffuser.T=50`, the equivalent of 80 noising steps is `diffuser.partial_T=20`. We strongly recommend sampling different values for `partial_T` however, to find the best parameters for your specific problem.
|
229 |
+
|
230 |
+
When doing partial diffusion, because we are now diffusing from a known structure, this creates certain constraints. You can still use the contig input, but *this has to yield a contig string exactly the same length as the input protein*. E.g. if you have a binder:target complex, and you want to diversify the binder (length 100, chain A), you would need to input something like this:
|
231 |
+
|
232 |
+
```
|
233 |
+
'contigmap.contigs=[100-100/0 B1-150]' diffuser.partial_T=20
|
234 |
+
```
|
235 |
+
The reason for this is that, if your input protein was only 80 amino acids, but you've specified a desired length of 100, we don't know where to diffuse those extra 20 amino acids from, and hence, they will not lie in the distribution that RFdiffusion has learned to denoise from.
|
236 |
+
|
237 |
+
An example of partial diffusion can be found in `./examples/design_partialdiffusion.sh`!
|
238 |
+
|
239 |
+
You can also keep parts of the sequence of the diffused chain fixed, if you want. An example of why you might want to do this is in the context of helical peptide binding. If you've threaded a helical peptide sequence onto an ideal helix, and now want to diversify the complex, allowing the helix to be predicted now not as an ideal helix, you might do something like:
|
240 |
+
|
241 |
+
```
|
242 |
+
'contigmap.contigs=[100-100/0 20-20]' 'contigmap.provide_seq=[100-119]' diffuser.partial_T=10
|
243 |
+
```
|
244 |
+
In this case, the 20aa chain is the helical peptide. The `contigmap.provide_seq` input is zero-indexed, and you can provide a range (so 100-119 is an inclusive range, unmasking the whole sequence of the peptide). Multiple sequence ranges can be provided separated by a comma, e.g. `'contigmap.provide_seq=[172-177,200-205]'`.
|
245 |
+
|
246 |
+
Note that the provide_seq option requires using a different model checkpoint, but this is automatically handled by the inference script.
|
247 |
+
|
248 |
+
An example of partial diffusion with providing sequence in diffused regions can be found in `./examples/design_partialdiffusion_withseq.sh`. The same example specifying multiple sequence ranges can be found in `./examples/design_partialdiffusion_multipleseq.sh`.
|
249 |
+
|
250 |
+
---
|
251 |
+
### Binder Design
|
252 |
+
Hopefully, it's now obvious how you might make a binder with diffusion! Indeed, RFdiffusion shows excellent *in silico* and experimental ability to design *de novo* binders.
|
253 |
+
|
254 |
+
<p align="center">
|
255 |
+
<img src="./img/binder.png" alt="alt text" width="950px" align="middle"/>
|
256 |
+
</p>
|
257 |
+
|
258 |
+
If chain B is your target, then you could do it like this:
|
259 |
+
|
260 |
+
```
|
261 |
+
./scripts/run_inference.py 'contigmap.contigs=[B1-100/0 100-100]' inference.output_prefix=test_outputs/binder_test inference.num_designs=10
|
262 |
+
```
|
263 |
+
|
264 |
+
This will generate 100 residue long binders to residues 1-100 of chain B.
|
265 |
+
|
266 |
+
However, this probably isn't the best way of making binders. Because diffusion is somewhat computationally-intensive, we need to try and make it as fast as possible. Providing the whole of your target, uncropped, is going to make diffusion very slow if your target is big (and most targets-of-interest, such as cell-surface receptors tend to be *very* big). One tried-and-true method to speed up binder design is to crop the target protein around the desired interface location. BUT! This creates a problem: if you crop your target and potentially expose hydrophobic core residues which were buried before the crop, how can you guarantee the binder will go to the intended interface site on the surface of the target, and not target the tantalizing hydrophobic patch you have just artificially created?
|
267 |
+
|
268 |
+
We solve this issue by providing the model with what we call "hotspot residues". The complex models we refer to earlier in this README file have all been trained with hotspot residues, in this training regime, during each example, the model is told (some of) the residues on the target protein which contact the target (i.e., resides that are part of the interface). The model readily learns that it should be making an interface which involved these hotspot residues. At inference time then, we can provide our own hotspot residues to define a region which the binder must contact. These are specified like this: `'ppi.hotspot_res=[A30,A33,A34]'`, where `A` is the chain ID in the input pdb file of the hotspot residue and the number is the residue index in the input pdb file of the hotspot residue.
|
269 |
+
|
270 |
+
Finally, it has been observed that the default RFdiffusion model often generates mostly helical binders. These have high computational and experimental success rates. However, there may be cases where other kinds of topologies may be desired. For this, we include a "beta" model, which generates a greater diversity of topologies, but has not been extensively experimentally validated. Try this at your own risk:
|
271 |
+
|
272 |
+
```
|
273 |
+
inference.ckpt_override_path=models/Complex_beta_ckpt.pt
|
274 |
+
```
|
275 |
+
|
276 |
+
An example of binder design with RFdiffusion can be found in `./examples/design_ppi.sh`.
|
277 |
+
|
278 |
+
---
|
279 |
+
|
280 |
+
## Practical Considerations for Binder Design
|
281 |
+
|
282 |
+
RFdiffusion is an extremely powerful binder design tool but it is not magic. In this section we will walk through some common pitfalls in RFdiffusion binder design and offer advice on how to get the most out of this method.
|
283 |
+
|
284 |
+
### Selecting a Target Site
|
285 |
+
Not every site on a target protein is a good candidate for binder design. For a site to be an attractive candidate for binding it should have >~3 hydrophobic residues for the binder to interact with. Binding to charged polar sites is still quite hard. Binding to sites with glycans close to them is also hard since they often become ordered upon binding and you will take an energetic hit for that. Historically, binder design has also avoided unstructured loops, it is not clear if this is still a requirement as RFdiffusion has been used to bind unstructured peptides which share a lot in common with unstructured loops.
|
286 |
+
|
287 |
+
### Truncating your Target Protein
|
288 |
+
RFdiffusion scales in runtime as O(N^2) where N is the number of residues in your system. As such, it is a very good idea to truncate large targets so that your computations are not unnecessarily expensive. RFdiffusion and all downstream steps (including AF2) are designed to allow for a truncated target. Truncating a target is an art. For some targets, such as multidomain extracellular membranes, a natural truncation point is where two domains are joined by a flexible linker. For other proteins, such as virus spike proteins, this truncation point is less obvious. Generally you want to preserve secondary structure and introduce as few chain breaks as possible. You should also try to leave ~10A of target protein on each side of your intended target site. We recommend using PyMol to truncate your target protein.
|
289 |
+
|
290 |
+
### Picking Hotspots
|
291 |
+
Hotspots are a feature that we integrated into the model to allow for the control of the site on the target which the binder will interact with. In the paper we define a hotspot as a residue on the target protein which is within 10A Cbeta distance of the binder. Of all of the hotspots which are identified on the target 0-20% of these hotspots are actually provided to the model and the rest are masked. This is important for understanding how you should pick hotspots at inference time.; the model is expecting to have to make more contacts than you specify. We normally recommend between 3-6 hotspots, you should run a few pilot runs before generating thousands of designs to make sure the number of hotspots you are providing will give results you like.
|
292 |
+
|
293 |
+
If you have run the previous PatchDock RifDock binder design pipeline, for the RFdiffusion paper we chose our hotspots to be the PatchDock residues of the target.
|
294 |
+
|
295 |
+
### Binder Design Scale
|
296 |
+
In the paper, we generated ~10,000 RFdiffusion binder backbones for each target. From this set of backbones we then generated two sequences per backbone using ProteinMPNN-FastRelax (described below). We screened these ~20,000 designs using AF2 with initial guess and target templating (also described below).
|
297 |
+
|
298 |
+
Given the high success rates we observed in the paper, for some targets it may be sufficient to only generate ~1,000 RFdiffusion backbones in a campaign. What you want is to get enough designs that pass pAE_interaction < 10 (described more in Binder Design Filtering section) such that you are able to fill a DNA order with these successful designs. We have found that designs that do not pass pAE_interaction < 10 are not worth ordering since they will likely not work experimentally.
|
299 |
+
|
300 |
+
### Sequence Design for Binders
|
301 |
+
You may have noticed that the binders designed by RFdiffusion come out with a poly-Glycine sequence. This is not a bug. RFdiffusion is a backbone-generation model and does not generate sequence for the designed region, therefore, another method must be used to assign a sequence to the binders. In the paper we use the ProteinMPNN-FastRelax protocol to do sequence design. We recommend that you do this as well. The code for this protocol can be found in [this GitHub repo](https://github.com/nrbennet/dl_binder_design). While we did not find the FastRelax part of the protocol to yield the large in silico success rate improvements that it yielded with the RifDock-generated docks, it is still a good way to increase your number of shots-on-goal for each (computationally expensive) RFdiffusion backbone. If you would prefer to simply run ProteinMPNN on your binders without the FastRelax step, that will work fine but will be more computationally expensive.
|
302 |
+
|
303 |
+
### Binder Design Filtering
|
304 |
+
One of the most important parts of the binder design pipeline is a filtering step to evaluate if your binders are actually predicted to work. In the paper we filtered using AF2 with an initial guess and target templating, scripts for this protocol are available [here](https://github.com/nrbennet/dl_binder_design). We have found that filtering at pae_interaction < 10 is a good predictor of a binder working experimentally.
|
305 |
+
|
306 |
+
---
|
307 |
+
|
308 |
+
### Fold Conditioning
|
309 |
+
Something that works really well is conditioning binder design (or monomer generation) on particular topologies. This is achieved by providing (partial) secondary structure and block adjacency information (to a model that has been trained to condition on this).
|
310 |
+
<p align="center">
|
311 |
+
<img src="./img/fold_cond.png" alt="alt text" width="950px" align="middle"/>
|
312 |
+
</p>
|
313 |
+
We are still working out the best way to actually generate this input at inference time, but for now, we have settled upon generating inputs directly from pdb structures. This permits 'low resolution' specification of output topology (i.e., I want a TIM barrel but I don't care precisely where resides are). In `helper_scripts/`, there's a script called `make_secstruc_adj.py`, which can be used as follows:
|
314 |
+
|
315 |
+
e.g. 1:
|
316 |
+
```
|
317 |
+
./make_secstruc_adj.py --input_pdb ./2KL8.pdb --out_dir /my/dir/for/adj_secstruct
|
318 |
+
```
|
319 |
+
or e.g. 2:
|
320 |
+
```
|
321 |
+
./make_secstruc_adj.py --pdb_dir ./pdbs/ --out_dir /my/dir/for/adj_secstruct
|
322 |
+
```
|
323 |
+
|
324 |
+
This will process either a single pdb, or a folder of pdbs, and output a secondary structure and adjacency pytorch file, ready to go into the model. For now (although this might not be necessary), you should also generate these files for the target protein (if you're doing PPI), and provide this to the model. You can then use these at inference as follows:
|
325 |
+
|
326 |
+
```
|
327 |
+
./scripts/run_inference.py inference.output_prefix=./scaffold_conditioned_test/test scaffoldguided.scaffoldguided=True scaffoldguided.target_pdb=False scaffoldguided.scaffold_dir=./examples/ppi_scaffolds_subset
|
328 |
+
```
|
329 |
+
|
330 |
+
A few extra things:
|
331 |
+
1) As mentioned above, for PPI, you will want to provide a target protein, along with its secondary structure and block adjacency. This can be done by adding:
|
332 |
+
|
333 |
+
```
|
334 |
+
scaffoldguided.target_pdb=True scaffoldguided.target_path=input_pdbs/insulin_target.pdb inference.output_prefix=insulin_binder/jordi_ss_insulin_noise0_job0 'ppi.hotspot_res=[A59,A83,A91]' scaffoldguided.target_ss=target_folds/insulin_target_ss.pt scaffoldguided.target_adj=target_folds/insulin_target_adj.pt
|
335 |
+
```
|
336 |
+
|
337 |
+
To generate these block adjacency and secondary structure inputs, you can use the helper script.
|
338 |
+
|
339 |
+
This will now generate 3-helix bundles to the insulin target.
|
340 |
+
|
341 |
+
For ppi, it's probably also worth adding this flag:
|
342 |
+
|
343 |
+
```
|
344 |
+
scaffoldguided.mask_loops=False
|
345 |
+
```
|
346 |
+
|
347 |
+
This is quite important to understand. During training, we mask some of the secondary structure and block adjacency. This is convenient, because it allows us to, at inference, easily add extra residues without having to specify precise secondary structure for every residue. E.g. if you want to make a long 3 helix bundle, you could mask the loops, and add e.g. 20 more 'mask' tokens to that loop. The model will then (presumbly) choose to make e.g. 15 of these residues into helices (to extend the 3HB), and then make a 5aa loop. But, you didn't have to specify that, which is nice. The way this would be done would be like this:
|
348 |
+
|
349 |
+
```
|
350 |
+
scaffoldguided.mask_loops=True scaffoldguided.sampled_insertion=15 scaffoldguided.sampled_N=5 scaffoldguided.sampled_C=5
|
351 |
+
```
|
352 |
+
|
353 |
+
This will, at each run of inference, sample up to 15 residues to insert into loops in your 3HB input, and up to 5 additional residues at N and C terminus.
|
354 |
+
This strategy is very useful if you don't have a large set of pdbs to make block adjacencies for. For example, we showed that we could generate loads of lengthened TIM barrels from a single starting pdb with this strategy. However, for PPI, if you're using the provided scaffold sets, it shouldn't be necessary (because there are so many scaffolds to start from, generating extra diversity isn't especially necessary).
|
355 |
+
|
356 |
+
Finally, if you have a big directory of block adjacency/secondary structure files, but don't want to use all of them, you can make a `.txt` file of the ones you want to use, and pass:
|
357 |
+
|
358 |
+
```
|
359 |
+
scaffoldguided.scaffold_list=path/to/list
|
360 |
+
```
|
361 |
+
|
362 |
+
For PPI, we've consistently seen that reducing the noise added at inference improves designs. This comes at the expense of diversity, but, given that the scaffold sets are huge, this probably doesn't matter too much. We therefore recommend lowering the noise. 0.5 is probably a good compromise:
|
363 |
+
|
364 |
+
```
|
365 |
+
denoiser.noise_scale_ca=0.5 denoiser.noise_scale_frame=0.5
|
366 |
+
```
|
367 |
+
This just scales the amount of noise we add to the translations (`noise_scale_ca`) and rotations (`noise_scale_frame`) by, in this case, 0.5.
|
368 |
+
|
369 |
+
An additional example of PPI with fold conditioning is available here: `./examples/design_ppi_scaffolded.sh`
|
370 |
+
|
371 |
+
---
|
372 |
+
|
373 |
+
### Generation of Symmetric Oligomers
|
374 |
+
We're going to switch gears from discussing PPI and look at another task at which RFdiffusion performs well on: symmetric oligomer design. This is done by symmetrising the noise we sample at t=T, and symmetrising the input at every timestep. We have currently implemented the following for use (with the others coming soon!):
|
375 |
+
- Cyclic symmetry
|
376 |
+
- Dihedral symmetry
|
377 |
+
- Tetrahedral symmetry
|
378 |
+
|
379 |
+
<p align="center">
|
380 |
+
<img src="./img/olig2.png" alt="alt text" width="1000px" align="middle"/>
|
381 |
+
</p>
|
382 |
+
|
383 |
+
Here's an example:
|
384 |
+
```
|
385 |
+
./scripts/run_inference.py --config-name symmetry inference.symmetry=tetrahedral 'contigmap.contigs=[360]' inference.output_prefix=test_sample/tetrahedral inference.num_designs=1
|
386 |
+
```
|
387 |
+
|
388 |
+
Here, we've specified a different `config` file (with `--config-name symmetry`). Because symmetric diffusion is quite different from the diffusion described above, we packaged a whole load of symmetry-related configs into a new file (see `configs/inference/symmetry.yml`). Using this config file now puts diffusion in `symmetry-mode`.
|
389 |
+
|
390 |
+
The symmetry type is then specified with `inference.symmetry=`. Here, we're specifying tetrahedral symmetry, but you could also choose cyclic (e.g. `c4`) or dihedral (e.g. `d2`).
|
391 |
+
|
392 |
+
The configmap.contigs length refers to the *total* length of your oligomer. Therefore, it *must* be divisible by *n* chains.
|
393 |
+
|
394 |
+
More examples of designing oligomers can be found here: `./examples/design_cyclic_oligos.sh`, `./examples/design_dihedral_oligos.sh`, `./examples/design_tetrahedral_oligos.sh`.
|
395 |
+
|
396 |
+
---
|
397 |
+
|
398 |
+
### Using Auxiliary Potentials
|
399 |
+
Performing diffusion with symmetrized noise may give you the idea that we could use other external interventions during the denoising process to guide diffusion. One such intervention that we have implemented is auxiliary potentials. Auxiliary potentials can be very useful for guiding the inference process. E.g. whereas in RFjoint inpainting, we have little/no control over the final shape of an output, in diffusion we can readily force the network to make, for example, a well-packed protein.
|
400 |
+
This is achieved in the updates we make at each step.
|
401 |
+
|
402 |
+
Let's go a little deeper into how the diffusion process works:
|
403 |
+
At timestep T (the first step of the reverse-diffusion inference process), we sample noise from a known *prior* distribution. The model then makes a prediction of what the final structure should be, and we use these two states (noise at time T, prediction of the structure at time 0) to back-calculate where t=T-1 would have been. We therefore have a vector pointing from each coordinate at time T, to their corresponding, back-calculated position at time T-1.
|
404 |
+
But, we want to be able to bias this update, to *push* the trajectory towards some desired state. This can be done by biasing that vector with another vector, which points towards a position where that residue would *reduce* the 'loss' as defined by your potential. E.g. if we want to use the `monomer_ROG` potential, which seeks to minimise the radius of gyration of the final protein, if the models prediction of t=0 is very elongated, each of those distant residues will have a larger gradient when we differentiate the `monomer_ROG` potential w.r.t. their positions. These gradients, along with the corresponding scale, can be combined into a vector, which is then combined with the original update vector to make a "biased update" at that timestep.
|
405 |
+
|
406 |
+
The exact parameters used when applying these potentials matter. If you weight them too strongly, you're not going to end up with a good protein. Too weak, and they'll have little effect. We've explored these potentials in a few different scenarios, and have set sensible defaults, if you want to use them. But, if you feel like they're too weak/strong, or you just fancy exploring, do play with the parameters (in the `potentials` part of the config file).
|
407 |
+
|
408 |
+
Potentials are specified as a list of strings with each string corresponding to a potential. The argument for potentials is `potentials.guiding_potentials`. Within the string per-potential arguments may be specified in the following syntax: `arg_name1:arg_value1,arg_name2:arg_value2,...,arg_nameN:arg_valueN`. The only argument that is required for each potential is the name of the potential that you wish to apply, the name of this argument is `type` as-in the type of potential you wish to use. Some potentials such as `olig_contacts` and `substrate_contacts` take global options such as `potentials.substrate`, see `config/inference/base.yml` for all the global arguments associated with potentials. Additionally, it is useful to have the effect of the potential "decay" throughout the trajectory, such that in the beginning the effect of the potential is 1x strength, and by the end is much weaker. These decays (`constant`,`linear`,`quadratic`,`cubic`) can be set with the `potentials.guide_decay` argument.
|
409 |
+
|
410 |
+
Here's an example of how to specify a potential:
|
411 |
+
|
412 |
+
```
|
413 |
+
potentials.guiding_potentials=[\"type:olig_contacts,weight_intra:1,weight_inter:0.1\"] potentials.olig_intra_all=True potentials.olig_inter_all=True potentials.guide_scale=2 potentials.guide_decay='quadratic'
|
414 |
+
```
|
415 |
+
|
416 |
+
We are still fully characterising how/when to use potentials, and we strongly recommend exploring different parameters yourself, as they are clearly somewhat case-dependent. So far, it is clear that they can be helpful for motif scaffolding and symmetric oligomer generation. However, they seem to interact weirdly with hotspot residues in PPI. We think we know why this is, and will work in the coming months to write better potentials for PPI. And please note, it is often good practice to start with *no potentials* as a baseline, then slowly increase their strength. For the oligomer contacts potentials, start with the ones provided in the examples, and note that the `intra` chain potential often should be higher than the `inter` chain potential.
|
417 |
+
|
418 |
+
We have already implemented several potentials but it is relatively straightforward to add more, if you want to push your designs towards some specified goal. The *only* condition is that, whatever potential you write, it is differentiable. Take a look at `potentials.potentials.py` for examples of the potentials we have implemented so far.
|
419 |
+
|
420 |
+
---
|
421 |
+
|
422 |
+
### Symmetric Motif Scaffolding.
|
423 |
+
We can also combine symmetric diffusion with motif scaffolding to scaffold motifs symmetrically.
|
424 |
+
Currently, we have one way for performing symmetric motif scaffolding. That is by specifying the position of the motif specified w.r.t. the symmetry axes.
|
425 |
+
|
426 |
+
<p align="center">
|
427 |
+
<img src="./img/sym_motif.png" alt="alt text" width="1000px" align="middle"/>
|
428 |
+
</p>
|
429 |
+
|
430 |
+
**Special input .pdb and contigs requirements**
|
431 |
+
|
432 |
+
For now, we require that a user have a symmetrized version of their motif in their input pdb for symmetric motif scaffolding. There are two main reasons for this. First, the model is trained by centering any motif at the origin, and thus the code also centers motifs at the origin automatically. Therefore, if your motif is not symmetrized, this centering action will result in an asymmetric unit that now has the origin and axes of symmetry running right through it (bad). Secondly, the diffusion code uses a canonical set of symmetry axes (rotation matrices) to propogate the asymmetric unit of a motif. In order to prevent accidentally running diffusion trajectories which are propogating your motif in ways you don't intend, we require that a user symmetrize an input using the RFdiffusion canonical symmetry axes.
|
433 |
+
|
434 |
+
**RFdiffusion canonical symmetry axes**
|
435 |
+
| Group | Axis |
|
436 |
+
|:----------:|:-------------:|
|
437 |
+
| Cyclic | Z |
|
438 |
+
| Dihedral (cyclic) | Z |
|
439 |
+
| Dihedral (flip/reflection) | X |
|
440 |
+
|
441 |
+
|
442 |
+
**Example: Inputs for symmetric motif scaffolding with motif position specified w.r.t the symmetry axes.**
|
443 |
+
|
444 |
+
This example script `examples/design_nickel.sh` can be used to scaffold the C4 symmetric Nickel binding domains shown in the RFdiffusion paper. It combines many concepts discussed earlier, including symmetric oligomer generation, motif scaffolding, and use of guiding potentials.
|
445 |
+
|
446 |
+
Note that the contigs should specify something that is precisely symmetric. Things will break if this is not the case.
|
447 |
+
|
448 |
+
---
|
449 |
+
|
450 |
+
### A Note on Model Weights
|
451 |
+
|
452 |
+
Because of everything we want diffusion to be able to do, there is not *One Model To Rule Them All*. E.g., if you want to run with secondary structure conditioning, this requires a different model than if you don't. Under the hood, we take care of most of this by default - we parse your input and work out the most appropriate checkpoint.
|
453 |
+
This is where the config setup is really useful. The exact model checkpoint used at inference contains in it all of the parameters is was trained with, so we can just populate the config file with those values, such that inference runs as designed.
|
454 |
+
If you do want to specify a different checkpoint (if, for example, we train a new model and you want to test it), you just have to make sure it's compatible with what you're doing. E.g. if you try and give secondary structure features to a model that wasn't trained with them, it'll crash.
|
455 |
+
|
456 |
+
### Things you might want to play with at inference time
|
457 |
+
|
458 |
+
Occasionally, it might good to try an alternative model (for example the active site model, or the beta binder model). These can be specified with `inference.ckpt_override_path`. We do not recommend using these outside of the described use cases, however, as there is not a guarantee they will understand other kinds of inputs.
|
459 |
+
|
460 |
+
For a full list of things that are implemented at inference, see the config file (`configs/inference/base.yml` or `configs/inference/symmetry.yml`). Although you can modify everything, this is not recommended unless you know what you're doing.
|
461 |
+
Generally, don't change the `model`, `preprocess` or `diffuser` configs. These pertain to how the model was trained, so it's unwise to change how you use the model at inference time.
|
462 |
+
However, the parameters below are definitely worth exploring:
|
463 |
+
-inference.final_step: This is when we stop the trajectory. We have seen that you can stop early, and the model is already making a good prediction of the final structure. This speeds up inference.
|
464 |
+
-denoiser.noise_scale_ca and denoiser.noise_scale_frame: These can be used to reduce the noise used during sampling (as discussed for PPI above). The default is 1 (the same noise added at training), but this can be reduced to e.g. 0.5, or even 0. This actually improves the quality of models coming out of diffusion, but at the expense of diversity. If you're not getting any good outputs, or if your problem is very constrained, you could try reducing the noise. While these parameters can be changed independently (for translations and rotations), we recommend keeping them tied.
|
465 |
+
|
466 |
+
### Understanding the output files
|
467 |
+
We output several different files.
|
468 |
+
1. The `.pdb` file. This is the final prediction out of the model. Note that every designed residue is output as a glycine (as we only designed the backbone), and no sidechains are output. This is because, even though RFdiffusion conditions on sidechains in an input motif, there is no loss applied to these predictions, so they can't strictly be trusted.
|
469 |
+
2. The `.trb` file. This contains useful metadata associated with that specific run, including the specific contig used (if length ranges were sampled), as well as the full config used by RFdiffusion. There are also a few other convenient items in this file:
|
470 |
+
- details about mapping (i.e. how residues in the input map to residues in the output)
|
471 |
+
- `con_ref_pdb_idx`/`con_hal_pdb_idx` - These are two arrays including the input pdb indices (in con_ref_pdb_idx), and where they are in the output pdb (in con_hal_pdb_idx). This only contains the chains where inpainting took place (i.e. not any fixed receptor/target chains)
|
472 |
+
- `con_ref_idx0`/`con_hal_idx0` - These are the same as above, but 0 indexed, and without chain information. This is useful for splicing coordinates out (to assess alignment etc).
|
473 |
+
- `inpaint_seq` - This details any residues that were masked during inference.
|
474 |
+
3. Trajectory files. By default, we output the full trajectories into the `/traj/` folder. These files can be opened in pymol, as multi-step pdbs. Note that these are ordered in reverse, so the first pdb is technically the last (t=1) prediction made by RFdiffusion during inference. We include both the `pX0` predictions (what the model predicted at each timestep) and the `Xt-1` trajectories (what went into the model at each timestep).
|
475 |
+
|
476 |
+
### Docker
|
477 |
+
|
478 |
+
We have provided a Dockerfile at `docker/Dockerfile` to help run RFDiffusion on HPC and other container orchestration systems. Follow these steps to build and run the container on your system:
|
479 |
+
|
480 |
+
1. Clone this repository with `git clone https://github.com/RosettaCommons/RFdiffusion.git` and then `cd RFdiffusion`
|
481 |
+
1. Verify that the Docker daemon is running on your system with `docker info`. You can find Docker installation instructions for Mac, WIndows, and Linux in the [official Docker docs](https://docs.docker.com/get-docker/). You may also consider [Finch](https://github.com/runfinch/finch), the open source client for container development.
|
482 |
+
1. Build the container image on your system with `docker build -f docker/Dockerfile -t rfdiffusion .`
|
483 |
+
1. Create some folders on your file system with `mkdir $HOME/inputs $HOME/outputs $HOME/models`
|
484 |
+
1. Download the RFDiffusion models with `bash scripts/download_models.sh $HOME/models`
|
485 |
+
1. Download a test file (or another of your choice) with `wget -P $HOME/inputs https://files.rcsb.org/view/5TPN.pdb`
|
486 |
+
1. Run the container with the following command:
|
487 |
+
|
488 |
+
```bash
|
489 |
+
docker run -it --rm --gpus all \
|
490 |
+
-v $HOME/models:$HOME/models \
|
491 |
+
-v $HOME/inputs:$HOME/inputs \
|
492 |
+
-v $HOME/outputs:$HOME/outputs \
|
493 |
+
rfdiffusion \
|
494 |
+
inference.output_prefix=$HOME/outputs/motifscaffolding \
|
495 |
+
inference.model_directory_path=$HOME/models \
|
496 |
+
inference.input_pdb=$HOME/inputs/5TPN.pdb \
|
497 |
+
inference.num_designs=3 \
|
498 |
+
'contigmap.contigs=[10-40/A163-181/10-40]'
|
499 |
+
```
|
500 |
+
|
501 |
+
This starts the `rfdiffusion` container, mounts the models, inputs, and outputs folders, passes all available GPUs, and then calls the `run_inference.py` script with the parameters specified.
|
502 |
+
|
503 |
+
### Conclusion
|
504 |
+
|
505 |
+
We are extremely excited to share RFdiffusion with the wider scientific community. We expect to push some updates as and when we make sizeable improvements in the coming months, so do stay tuned. We realize it may take some time to get used to executing RFdiffusion with perfect syntax (sometimes Hydra is hard), so please don't hesitate to create GitHub issues if you need help, we will respond as often as we can.
|
506 |
+
|
507 |
+
Now, let's go make some proteins. Have fun!
|
508 |
+
|
509 |
+
\- Joe, David, Nate, Brian, Jason, and the RFdiffusion team.
|
510 |
+
|
511 |
+
---
|
512 |
+
|
513 |
+
RFdiffusion builds directly on the architecture and trained parameters of RoseTTAFold. We therefore thank Frank DiMaio and Minkyung Baek, who developed RoseTTAFold.
|
514 |
+
RFdiffusion is released under an open source BSD License (see LICENSE file). It is free for both non-profit and for-profit use.
|
515 |
+
|
516 |
+
|
appverifUI.dll
ADDED
Binary file (112 kB). View file
|
|
config/inference/base.yaml
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base inference Configuration.
|
2 |
+
|
3 |
+
inference:
|
4 |
+
input_pdb: null
|
5 |
+
num_designs: 10
|
6 |
+
design_startnum: 0
|
7 |
+
ckpt_override_path: null
|
8 |
+
symmetry: null
|
9 |
+
recenter: True
|
10 |
+
radius: 10.0
|
11 |
+
model_only_neighbors: False
|
12 |
+
output_prefix: samples/design
|
13 |
+
write_trajectory: True
|
14 |
+
scaffold_guided: False
|
15 |
+
model_runner: SelfConditioning
|
16 |
+
cautious: True
|
17 |
+
align_motif: True
|
18 |
+
symmetric_self_cond: True
|
19 |
+
final_step: 1
|
20 |
+
deterministic: False
|
21 |
+
trb_save_ckpt_path: null
|
22 |
+
schedule_directory_path: null
|
23 |
+
model_directory_path: null
|
24 |
+
|
25 |
+
contigmap:
|
26 |
+
contigs: null
|
27 |
+
inpaint_seq: null
|
28 |
+
provide_seq: null
|
29 |
+
length: null
|
30 |
+
|
31 |
+
model:
|
32 |
+
n_extra_block: 4
|
33 |
+
n_main_block: 32
|
34 |
+
n_ref_block: 4
|
35 |
+
d_msa: 256
|
36 |
+
d_msa_full: 64
|
37 |
+
d_pair: 128
|
38 |
+
d_templ: 64
|
39 |
+
n_head_msa: 8
|
40 |
+
n_head_pair: 4
|
41 |
+
n_head_templ: 4
|
42 |
+
d_hidden: 32
|
43 |
+
d_hidden_templ: 32
|
44 |
+
p_drop: 0.15
|
45 |
+
SE3_param_full:
|
46 |
+
num_layers: 1
|
47 |
+
num_channels: 32
|
48 |
+
num_degrees: 2
|
49 |
+
n_heads: 4
|
50 |
+
div: 4
|
51 |
+
l0_in_features: 8
|
52 |
+
l0_out_features: 8
|
53 |
+
l1_in_features: 3
|
54 |
+
l1_out_features: 2
|
55 |
+
num_edge_features: 32
|
56 |
+
SE3_param_topk:
|
57 |
+
num_layers: 1
|
58 |
+
num_channels: 32
|
59 |
+
num_degrees: 2
|
60 |
+
n_heads: 4
|
61 |
+
div: 4
|
62 |
+
l0_in_features: 64
|
63 |
+
l0_out_features: 64
|
64 |
+
l1_in_features: 3
|
65 |
+
l1_out_features: 2
|
66 |
+
num_edge_features: 64
|
67 |
+
freeze_track_motif: False
|
68 |
+
use_motif_timestep: False
|
69 |
+
|
70 |
+
diffuser:
|
71 |
+
T: 50
|
72 |
+
b_0: 1e-2
|
73 |
+
b_T: 7e-2
|
74 |
+
schedule_type: linear
|
75 |
+
so3_type: igso3
|
76 |
+
crd_scale: 0.25
|
77 |
+
partial_T: null
|
78 |
+
so3_schedule_type: linear
|
79 |
+
min_b: 1.5
|
80 |
+
max_b: 2.5
|
81 |
+
min_sigma: 0.02
|
82 |
+
max_sigma: 1.5
|
83 |
+
|
84 |
+
denoiser:
|
85 |
+
noise_scale_ca: 1
|
86 |
+
final_noise_scale_ca: 1
|
87 |
+
ca_noise_schedule_type: constant
|
88 |
+
noise_scale_frame: 1
|
89 |
+
final_noise_scale_frame: 1
|
90 |
+
frame_noise_schedule_type: constant
|
91 |
+
|
92 |
+
ppi:
|
93 |
+
hotspot_res: null
|
94 |
+
|
95 |
+
potentials:
|
96 |
+
guiding_potentials: null
|
97 |
+
guide_scale: 10
|
98 |
+
guide_decay: constant
|
99 |
+
olig_inter_all : null
|
100 |
+
olig_intra_all : null
|
101 |
+
olig_custom_contact : null
|
102 |
+
substrate: null
|
103 |
+
|
104 |
+
contig_settings:
|
105 |
+
ref_idx: null
|
106 |
+
hal_idx: null
|
107 |
+
idx_rf: null
|
108 |
+
inpaint_seq_tensor: null
|
109 |
+
|
110 |
+
preprocess:
|
111 |
+
sidechain_input: False
|
112 |
+
motif_sidechain_input: True
|
113 |
+
d_t1d: 22
|
114 |
+
d_t2d: 44
|
115 |
+
prob_self_cond: 0.0
|
116 |
+
str_self_cond: False
|
117 |
+
predict_previous: False
|
118 |
+
|
119 |
+
logging:
|
120 |
+
inputs: False
|
121 |
+
|
122 |
+
scaffoldguided:
|
123 |
+
scaffoldguided: False
|
124 |
+
target_pdb: False
|
125 |
+
target_path: null
|
126 |
+
scaffold_list: null
|
127 |
+
scaffold_dir: null
|
128 |
+
sampled_insertion: 0
|
129 |
+
sampled_N: 0
|
130 |
+
sampled_C: 0
|
131 |
+
ss_mask: 0
|
132 |
+
systematic: False
|
133 |
+
target_ss: null
|
134 |
+
target_adj: null
|
135 |
+
mask_loops: True
|
136 |
+
contig_crop: null
|
config/inference/symmetry.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Config for sampling symmetric assemblies.
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- base
|
5 |
+
|
6 |
+
inference:
|
7 |
+
# Symmetry to sample
|
8 |
+
# Available symmetries:
|
9 |
+
# - Cyclic symmetry (C_n) # call as c5
|
10 |
+
# - Dihedral symmetry (D_n) # call as d5
|
11 |
+
# - Tetrahedral symmetry # call as tetrahedral
|
12 |
+
# - Octahedral symmetry # call as octahedral
|
13 |
+
# - Icosahedral symmetry # call as icosahedral
|
14 |
+
symmetry: c2
|
15 |
+
|
16 |
+
# Set to true for computational efficiency
|
17 |
+
# to avoid memory overhead of modeling all subunits.
|
18 |
+
model_only_neighbors: False
|
19 |
+
|
20 |
+
# Output directory of samples.
|
21 |
+
output_prefix: samples/c2
|
22 |
+
|
23 |
+
contigmap:
|
24 |
+
# Specify a single integer value to sample unconditionally.
|
25 |
+
# Must be evenly divisible by the number of chains in the symmetry.
|
26 |
+
contigs: ['100']
|
docker/Dockerfile
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Usage:
|
2 |
+
# git clone https://github.com/RosettaCommons/RFdiffusion.git
|
3 |
+
# cd RFdiffusion
|
4 |
+
# docker build -f docker/Dockerfile -t rfdiffusion .
|
5 |
+
# mkdir $HOME/inputs $HOME/outputs $HOME/models
|
6 |
+
# bash scripts/download_models.sh $HOME/models
|
7 |
+
# wget -P $HOME/inputs https://files.rcsb.org/view/5TPN.pdb
|
8 |
+
|
9 |
+
# docker run -it --rm --gpus all \
|
10 |
+
# -v $HOME/models:$HOME/models \
|
11 |
+
# -v $HOME/inputs:$HOME/inputs \
|
12 |
+
# -v $HOME/outputs:$HOME/outputs \
|
13 |
+
# rfdiffusion \
|
14 |
+
# inference.output_prefix=$HOME/outputs/motifscaffolding \
|
15 |
+
# inference.model_directory_path=$HOME/models \
|
16 |
+
# inference.input_pdb=$HOME/inputs/5TPN.pdb \
|
17 |
+
# inference.num_designs=3 \
|
18 |
+
# 'contigmap.contigs=[10-40/A163-181/10-40]'
|
19 |
+
|
20 |
+
FROM nvcr.io/nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04
|
21 |
+
|
22 |
+
COPY . /app/RFdiffusion/
|
23 |
+
|
24 |
+
RUN apt-get -q update \
|
25 |
+
&& DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
|
26 |
+
git \
|
27 |
+
python3.9 \
|
28 |
+
python3-pip \
|
29 |
+
&& python3.9 -m pip install -q -U --no-cache-dir pip \
|
30 |
+
&& rm -rf /var/lib/apt/lists/* \
|
31 |
+
&& apt-get autoremove -y \
|
32 |
+
&& apt-get clean \
|
33 |
+
&& pip install -q --no-cache-dir \
|
34 |
+
dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html \
|
35 |
+
torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 \
|
36 |
+
e3nn==0.3.3 \
|
37 |
+
wandb==0.12.0 \
|
38 |
+
pynvml==11.0.0 \
|
39 |
+
git+https://github.com/NVIDIA/dllogger#egg=dllogger \
|
40 |
+
decorator==5.1.0 \
|
41 |
+
hydra-core==1.3.2 \
|
42 |
+
pyrsistent==0.19.3 \
|
43 |
+
/app/RFdiffusion/env/SE3Transformer \
|
44 |
+
&& pip install --no-cache-dir /app/RFdiffusion --no-deps
|
45 |
+
|
46 |
+
WORKDIR /app/RFdiffusion
|
47 |
+
|
48 |
+
ENV DGLBACKEND="pytorch"
|
49 |
+
|
50 |
+
ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
|
env/SE3Transformer/.dockerignore
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.Trash-0
|
2 |
+
.git
|
3 |
+
data/
|
4 |
+
.DS_Store
|
5 |
+
*wandb/
|
6 |
+
*.pt
|
7 |
+
*.swp
|
8 |
+
|
9 |
+
# added by FAFU
|
10 |
+
.idea/
|
11 |
+
cache/
|
12 |
+
downloaded/
|
13 |
+
*.lprof
|
14 |
+
|
15 |
+
# Byte-compiled / optimized / DLL files
|
16 |
+
__pycache__/
|
17 |
+
*.py[cod]
|
18 |
+
*$py.class
|
19 |
+
|
20 |
+
# C extensions
|
21 |
+
*.so
|
22 |
+
|
23 |
+
# Distribution / packaging
|
24 |
+
.Python
|
25 |
+
build/
|
26 |
+
develop-eggs/
|
27 |
+
dist/
|
28 |
+
downloads/
|
29 |
+
eggs/
|
30 |
+
.eggs/
|
31 |
+
lib/
|
32 |
+
lib64/
|
33 |
+
parts/
|
34 |
+
sdist/
|
35 |
+
var/
|
36 |
+
wheels/
|
37 |
+
*.egg-info/
|
38 |
+
.installed.cfg
|
39 |
+
*.egg
|
40 |
+
MANIFEST
|
41 |
+
|
42 |
+
# PyInstaller
|
43 |
+
# Usually these files are written by a python script from a template
|
44 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
45 |
+
*.manifest
|
46 |
+
*.spec
|
47 |
+
|
48 |
+
# Installer logs
|
49 |
+
pip-log.txt
|
50 |
+
pip-delete-this-directory.txt
|
51 |
+
|
52 |
+
# Unit test / coverage reports
|
53 |
+
htmlcov/
|
54 |
+
.tox/
|
55 |
+
.coverage
|
56 |
+
.coverage.*
|
57 |
+
.cache
|
58 |
+
nosetests.xml
|
59 |
+
coverage.xml
|
60 |
+
*.cover
|
61 |
+
.hypothesis/
|
62 |
+
.pytest_cache/
|
63 |
+
|
64 |
+
# Translations
|
65 |
+
*.mo
|
66 |
+
*.pot
|
67 |
+
|
68 |
+
# Django stuff:
|
69 |
+
*.log
|
70 |
+
local_settings.py
|
71 |
+
db.sqlite3
|
72 |
+
|
73 |
+
# Flask stuff:
|
74 |
+
instance/
|
75 |
+
.webassets-cache
|
76 |
+
|
77 |
+
# Scrapy stuff:
|
78 |
+
.scrapy
|
79 |
+
|
80 |
+
# Sphinx documentation
|
81 |
+
docs/_build/
|
82 |
+
|
83 |
+
# PyBuilder
|
84 |
+
target/
|
85 |
+
|
86 |
+
# Jupyter Notebook
|
87 |
+
.ipynb_checkpoints
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
.python-version
|
91 |
+
|
92 |
+
# celery beat schedule file
|
93 |
+
celerybeat-schedule
|
94 |
+
|
95 |
+
# SageMath parsed files
|
96 |
+
*.sage.py
|
97 |
+
|
98 |
+
# Environments
|
99 |
+
.env
|
100 |
+
.venv
|
101 |
+
env/
|
102 |
+
venv/
|
103 |
+
ENV/
|
104 |
+
env.bak/
|
105 |
+
venv.bak/
|
106 |
+
|
107 |
+
# Spyder project settings
|
108 |
+
.spyderproject
|
109 |
+
.spyproject
|
110 |
+
|
111 |
+
# Rope project settings
|
112 |
+
.ropeproject
|
113 |
+
|
114 |
+
# mkdocs documentation
|
115 |
+
/site
|
116 |
+
|
117 |
+
# mypy
|
118 |
+
.mypy_cache/
|
119 |
+
|
120 |
+
**/benchmark
|
121 |
+
**/results
|
122 |
+
*.pkl
|
123 |
+
*.log
|
env/SE3Transformer/.gitignore
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data/
|
2 |
+
.DS_Store
|
3 |
+
*wandb/
|
4 |
+
*.pt
|
5 |
+
*.swp
|
6 |
+
|
7 |
+
# added by FAFU
|
8 |
+
.idea/
|
9 |
+
cache/
|
10 |
+
downloaded/
|
11 |
+
*.lprof
|
12 |
+
|
13 |
+
# Byte-compiled / optimized / DLL files
|
14 |
+
__pycache__/
|
15 |
+
*.py[cod]
|
16 |
+
*$py.class
|
17 |
+
|
18 |
+
# C extensions
|
19 |
+
*.so
|
20 |
+
|
21 |
+
# Distribution / packaging
|
22 |
+
.Python
|
23 |
+
build/
|
24 |
+
develop-eggs/
|
25 |
+
dist/
|
26 |
+
downloads/
|
27 |
+
eggs/
|
28 |
+
.eggs/
|
29 |
+
lib/
|
30 |
+
lib64/
|
31 |
+
parts/
|
32 |
+
sdist/
|
33 |
+
var/
|
34 |
+
wheels/
|
35 |
+
*.egg-info/
|
36 |
+
.installed.cfg
|
37 |
+
*.egg
|
38 |
+
MANIFEST
|
39 |
+
|
40 |
+
# PyInstaller
|
41 |
+
# Usually these files are written by a python script from a template
|
42 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
43 |
+
*.manifest
|
44 |
+
*.spec
|
45 |
+
|
46 |
+
# Installer logs
|
47 |
+
pip-log.txt
|
48 |
+
pip-delete-this-directory.txt
|
49 |
+
|
50 |
+
# Unit test / coverage reports
|
51 |
+
htmlcov/
|
52 |
+
.tox/
|
53 |
+
.coverage
|
54 |
+
.coverage.*
|
55 |
+
.cache
|
56 |
+
nosetests.xml
|
57 |
+
coverage.xml
|
58 |
+
*.cover
|
59 |
+
.hypothesis/
|
60 |
+
.pytest_cache/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
db.sqlite3
|
70 |
+
|
71 |
+
# Flask stuff:
|
72 |
+
instance/
|
73 |
+
.webassets-cache
|
74 |
+
|
75 |
+
# Scrapy stuff:
|
76 |
+
.scrapy
|
77 |
+
|
78 |
+
# Sphinx documentation
|
79 |
+
docs/_build/
|
80 |
+
|
81 |
+
# PyBuilder
|
82 |
+
target/
|
83 |
+
|
84 |
+
# Jupyter Notebook
|
85 |
+
.ipynb_checkpoints
|
86 |
+
|
87 |
+
# pyenv
|
88 |
+
.python-version
|
89 |
+
|
90 |
+
# celery beat schedule file
|
91 |
+
celerybeat-schedule
|
92 |
+
|
93 |
+
# SageMath parsed files
|
94 |
+
*.sage.py
|
95 |
+
|
96 |
+
# Environments
|
97 |
+
.env
|
98 |
+
.venv
|
99 |
+
env/
|
100 |
+
venv/
|
101 |
+
ENV/
|
102 |
+
env.bak/
|
103 |
+
venv.bak/
|
104 |
+
|
105 |
+
# Spyder project settings
|
106 |
+
.spyderproject
|
107 |
+
.spyproject
|
108 |
+
|
109 |
+
# Rope project settings
|
110 |
+
.ropeproject
|
111 |
+
|
112 |
+
# mkdocs documentation
|
113 |
+
/site
|
114 |
+
|
115 |
+
# mypy
|
116 |
+
.mypy_cache/
|
117 |
+
|
118 |
+
**/benchmark
|
119 |
+
**/results
|
120 |
+
*.pkl
|
121 |
+
*.log
|
env/SE3Transformer/Dockerfile
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
# run docker daemon with --default-runtime=nvidia for GPU detection during build
|
25 |
+
# multistage build for DGL with CUDA and FP16
|
26 |
+
|
27 |
+
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3
|
28 |
+
|
29 |
+
FROM ${FROM_IMAGE_NAME} AS dgl_builder
|
30 |
+
|
31 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
32 |
+
RUN apt-get update \
|
33 |
+
&& apt-get install -y git build-essential python3-dev make cmake \
|
34 |
+
&& rm -rf /var/lib/apt/lists/*
|
35 |
+
WORKDIR /dgl
|
36 |
+
RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
|
37 |
+
RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake
|
38 |
+
WORKDIR build
|
39 |
+
RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
|
40 |
+
RUN make -j8
|
41 |
+
|
42 |
+
|
43 |
+
FROM ${FROM_IMAGE_NAME}
|
44 |
+
|
45 |
+
RUN rm -rf /workspace/*
|
46 |
+
WORKDIR /workspace/se3-transformer
|
47 |
+
|
48 |
+
# copy built DGL and install it
|
49 |
+
COPY --from=dgl_builder /dgl ./dgl
|
50 |
+
RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl
|
51 |
+
|
52 |
+
ADD requirements.txt .
|
53 |
+
RUN pip install --no-cache-dir --upgrade --pre pip
|
54 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
55 |
+
ADD . .
|
56 |
+
|
57 |
+
ENV DGLBACKEND=pytorch
|
58 |
+
ENV OMP_NUM_THREADS=1
|
env/SE3Transformer/LICENSE
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2021 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4 |
+
|
5 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
6 |
+
|
7 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
env/SE3Transformer/NOTICE
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SE(3)-Transformer PyTorch
|
2 |
+
|
3 |
+
This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public
|
4 |
+
licensed under the MIT License.
|
5 |
+
|
6 |
+
This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch
|
7 |
+
licensed under the MIT License.
|
env/SE3Transformer/README.md
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SE(3)-Transformers For PyTorch
|
2 |
+
|
3 |
+
This repository provides a script and recipe to train the SE(3)-Transformer model to achieve state-of-the-art accuracy. The content of this repository is tested and maintained by NVIDIA.
|
4 |
+
|
5 |
+
## Table Of Contents
|
6 |
+
- [Model overview](#model-overview)
|
7 |
+
* [Model architecture](#model-architecture)
|
8 |
+
* [Default configuration](#default-configuration)
|
9 |
+
* [Feature support matrix](#feature-support-matrix)
|
10 |
+
* [Features](#features)
|
11 |
+
* [Mixed precision training](#mixed-precision-training)
|
12 |
+
* [Enabling mixed precision](#enabling-mixed-precision)
|
13 |
+
* [Enabling TF32](#enabling-tf32)
|
14 |
+
* [Glossary](#glossary)
|
15 |
+
- [Setup](#setup)
|
16 |
+
* [Requirements](#requirements)
|
17 |
+
- [Quick Start Guide](#quick-start-guide)
|
18 |
+
- [Advanced](#advanced)
|
19 |
+
* [Scripts and sample code](#scripts-and-sample-code)
|
20 |
+
* [Parameters](#parameters)
|
21 |
+
* [Command-line options](#command-line-options)
|
22 |
+
* [Getting the data](#getting-the-data)
|
23 |
+
* [Dataset guidelines](#dataset-guidelines)
|
24 |
+
* [Multi-dataset](#multi-dataset)
|
25 |
+
* [Training process](#training-process)
|
26 |
+
* [Inference process](#inference-process)
|
27 |
+
- [Performance](#performance)
|
28 |
+
* [Benchmarking](#benchmarking)
|
29 |
+
* [Training performance benchmark](#training-performance-benchmark)
|
30 |
+
* [Inference performance benchmark](#inference-performance-benchmark)
|
31 |
+
* [Results](#results)
|
32 |
+
* [Training accuracy results](#training-accuracy-results)
|
33 |
+
* [Training accuracy: NVIDIA DGX A100 (8x A100 80GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-80gb)
|
34 |
+
* [Training accuracy: NVIDIA DGX-1 (8x V100 16GB)](#training-accuracy-nvidia-dgx-1-8x-v100-16gb)
|
35 |
+
* [Training stability test](#training-stability-test)
|
36 |
+
* [Training performance results](#training-performance-results)
|
37 |
+
* [Training performance: NVIDIA DGX A100 (8x A100 80GB)](#training-performance-nvidia-dgx-a100-8x-a100-80gb)
|
38 |
+
* [Training performance: NVIDIA DGX-1 (8x V100 16GB)](#training-performance-nvidia-dgx-1-8x-v100-16gb)
|
39 |
+
* [Inference performance results](#inference-performance-results)
|
40 |
+
* [Inference performance: NVIDIA DGX A100 (1x A100 80GB)](#inference-performance-nvidia-dgx-a100-1x-a100-80gb)
|
41 |
+
* [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
|
42 |
+
- [Release notes](#release-notes)
|
43 |
+
* [Changelog](#changelog)
|
44 |
+
* [Known issues](#known-issues)
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
## Model overview
|
49 |
+
|
50 |
+
|
51 |
+
The **SE(3)-Transformer** is a Graph Neural Network using a variant of [self-attention](https://arxiv.org/abs/1706.03762v5) for 3D points and graphs processing.
|
52 |
+
This model is [equivariant](https://en.wikipedia.org/wiki/Equivariant_map) under [continuous 3D roto-translations](https://en.wikipedia.org/wiki/Euclidean_group), meaning that when the inputs (graphs or sets of points) rotate in 3D space (or more generally experience a [proper rigid transformation](https://en.wikipedia.org/wiki/Rigid_transformation)), the model outputs either stay invariant or transform with the input.
|
53 |
+
A mathematical guarantee of equivariance is important to ensure stable and predictable performance in the presence of nuisance transformations of the data input and when the problem has some inherent symmetries we want to exploit.
|
54 |
+
|
55 |
+
|
56 |
+
The model is based on the following publications:
|
57 |
+
- [SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks](https://arxiv.org/abs/2006.10503) (NeurIPS 2020) by Fabian B. Fuchs, Daniel E. Worrall, et al.
|
58 |
+
- [Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds](https://arxiv.org/abs/1802.08219) by Nathaniel Thomas, Tess Smidt, et al.
|
59 |
+
|
60 |
+
A follow-up paper explains how this model can be used iteratively, for example, to predict or refine protein structures:
|
61 |
+
|
62 |
+
- [Iterative SE(3)-Transformers](https://arxiv.org/abs/2102.13419) by Fabian B. Fuchs, Daniel E. Worrall, et al.
|
63 |
+
|
64 |
+
Just like [the official implementation](https://github.com/FabianFuchsML/se3-transformer-public), this implementation uses [PyTorch](https://pytorch.org/) and the [Deep Graph Library (DGL)](https://www.dgl.ai/).
|
65 |
+
|
66 |
+
The main differences between this implementation of SE(3)-Transformers and the official one are the following:
|
67 |
+
|
68 |
+
- Training and inference support for multiple GPUs
|
69 |
+
- Training and inference support for [Mixed Precision](https://arxiv.org/abs/1710.03740)
|
70 |
+
- The [QM9 dataset from DGL](https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset) is used and automatically downloaded
|
71 |
+
- Significantly increased throughput
|
72 |
+
- Significantly reduced memory consumption
|
73 |
+
- The use of layer normalization in the fully connected radial profile layers is an option (`--use_layer_norm`), off by default
|
74 |
+
- The use of equivariant normalization between attention layers is an option (`--norm`), off by default
|
75 |
+
- The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonic) and [Clebsch–Gordan coefficients](https://en.wikipedia.org/wiki/Clebsch%E2%80%93Gordan_coefficients), used to compute bases matrices, are computed with the [e3nn library](https://e3nn.org/)
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
This model enables you to predict quantum chemical properties of small organic molecules in the [QM9 dataset](https://www.nature.com/articles/sdata201422).
|
80 |
+
In this case, the exploited symmetry is that these properties do not depend on the orientation or position of the molecules in space.
|
81 |
+
|
82 |
+
|
83 |
+
This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, NVIDIA Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results up to 1.5x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
|
84 |
+
|
85 |
+
### Model architecture
|
86 |
+
|
87 |
+
The model consists of stacked layers of equivariant graph self-attention and equivariant normalization.
|
88 |
+
Lastly, a Tensor Field Network convolution is applied to obtain invariant features. Graph pooling (mean or max over the nodes) is applied to these features, and the result is fed to a final MLP to get scalar predictions.
|
89 |
+
|
90 |
+
In this setup, the model is a graph-to-scalar network. The pooling can be removed to obtain a graph-to-graph network, and the final TFN can be modified to output features of any type (invariant scalars, 3D vectors, ...).
|
91 |
+
|
92 |
+
|
93 |
+

|
94 |
+
|
95 |
+
|
96 |
+
### Default configuration
|
97 |
+
|
98 |
+
|
99 |
+
SE(3)-Transformers introduce a self-attention layer for graphs that is equivariant to 3D roto-translations. It achieves this by leveraging Tensor Field Networks to build attention weights that are invariant and attention values that are equivariant.
|
100 |
+
Combining the equivariant values with the invariant weights gives rise to an equivariant output. This output is normalized while preserving equivariance thanks to equivariant normalization layers operating on feature norms.
|
101 |
+
|
102 |
+
|
103 |
+
The following features were implemented in this model:
|
104 |
+
|
105 |
+
- Support for edge features of any degree (1D, 3D, 5D, ...), whereas the official implementation only supports scalar invariant edge features (degree 0). Edge features with a degree greater than one are
|
106 |
+
concatenated to node features of the same degree. This is required in order to reproduce published results on point cloud processing.
|
107 |
+
- Data-parallel multi-GPU training (DDP)
|
108 |
+
- Mixed precision training (autocast, gradient scaling)
|
109 |
+
- Gradient accumulation
|
110 |
+
- Model checkpointing
|
111 |
+
|
112 |
+
|
113 |
+
The following performance optimizations were implemented in this model:
|
114 |
+
|
115 |
+
|
116 |
+
**General optimizations**
|
117 |
+
|
118 |
+
- The option is provided to precompute bases at the beginning of the training instead of computing them at the beginning of each forward pass (`--precompute_bases`)
|
119 |
+
- The bases computation is just-in-time (JIT) compiled with `torch.jit.script`
|
120 |
+
- The Clebsch-Gordon coefficients are cached in RAM
|
121 |
+
|
122 |
+
|
123 |
+
**Tensor Field Network optimizations**
|
124 |
+
|
125 |
+
- The last layer of each radial profile network does not add any bias in order to avoid large broadcasting operations
|
126 |
+
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
|
127 |
+
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
|
128 |
+
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
|
129 |
+
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
|
130 |
+
|
131 |
+
**Self-attention optimizations**
|
132 |
+
|
133 |
+
- Attention keys and values are computed by a single partial TFN graph convolution in each attention layer instead of two
|
134 |
+
- Graph operations for different output degrees may be fused together if conditions are met
|
135 |
+
|
136 |
+
|
137 |
+
**Normalization optimizations**
|
138 |
+
|
139 |
+
- The equivariant normalization layer is optimized from multiple layer normalizations to a group normalization on fused norms when certain conditions are met
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
Competitive training results and analysis are provided for the following hyperparameters (identical to the ones in the original publication):
|
144 |
+
- Number of layers: 7
|
145 |
+
- Number of degrees: 4
|
146 |
+
- Number of channels: 32
|
147 |
+
- Number of attention heads: 8
|
148 |
+
- Channels division: 2
|
149 |
+
- Use of equivariant normalization: true
|
150 |
+
- Use of layer normalization: true
|
151 |
+
- Pooling: max
|
152 |
+
|
153 |
+
|
154 |
+
### Feature support matrix
|
155 |
+
|
156 |
+
This model supports the following features::
|
157 |
+
|
158 |
+
| Feature | SE(3)-Transformer
|
159 |
+
|-----------------------|--------------------------
|
160 |
+
|Automatic mixed precision (AMP) | Yes
|
161 |
+
|Distributed data parallel (DDP) | Yes
|
162 |
+
|
163 |
+
#### Features
|
164 |
+
|
165 |
+
|
166 |
+
**Distributed data parallel (DDP)**
|
167 |
+
|
168 |
+
[DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implements data parallelism at the module level that can run across multiple GPUs or machines.
|
169 |
+
|
170 |
+
**Automatic Mixed Precision (AMP)**
|
171 |
+
|
172 |
+
This implementation uses the native PyTorch AMP implementation of mixed precision training. It allows us to use FP16 training with FP32 master weights by modifying just a few lines of code. A detailed explanation of mixed precision can be found in the next section.
|
173 |
+
|
174 |
+
### Mixed precision training
|
175 |
+
|
176 |
+
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in NVIDIA Volta, and following with both the NVIDIA Turing and NVIDIA Ampere Architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using [mixed precision training](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) previously required two steps:
|
177 |
+
1. Porting the model to use the FP16 data type where appropriate.
|
178 |
+
2. Adding loss scaling to preserve small gradient values.
|
179 |
+
|
180 |
+
AMP enables mixed precision training on NVIDIA Volta, NVIDIA Turing, and NVIDIA Ampere GPU architectures automatically. The PyTorch framework code makes all necessary model changes internally.
|
181 |
+
|
182 |
+
For information about:
|
183 |
+
- How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation.
|
184 |
+
- Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
|
185 |
+
- APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
|
186 |
+
|
187 |
+
#### Enabling mixed precision
|
188 |
+
|
189 |
+
Mixed precision is enabled in PyTorch by using the native [Automatic Mixed Precision package](https://pytorch.org/docs/stable/amp.html), which casts variables to half-precision upon retrieval while storing variables in single-precision format. Furthermore, to preserve small gradient magnitudes in backpropagation, a [loss scaling](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#lossscaling) step must be included when applying gradients. In PyTorch, loss scaling can be applied automatically using a `GradScaler`.
|
190 |
+
Automatic Mixed Precision makes all the adjustments internally in PyTorch, providing two benefits over manual operations. First, programmers need not modify network model code, reducing development and maintenance effort. Second, using AMP maintains forward and backward compatibility with all the APIs for defining and running PyTorch models.
|
191 |
+
|
192 |
+
To enable mixed precision, you can simply use the `--amp` flag when running the training or inference scripts.
|
193 |
+
|
194 |
+
#### Enabling TF32
|
195 |
+
|
196 |
+
TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math, also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on NVIDIA Volta GPUs.
|
197 |
+
|
198 |
+
TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models that require a high dynamic range for weights or activations.
|
199 |
+
|
200 |
+
For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
|
201 |
+
|
202 |
+
TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
### Glossary
|
207 |
+
|
208 |
+
**Degree (type)**
|
209 |
+
|
210 |
+
In the model, every feature (input, output and hidden) transforms in an equivariant way in relation to the input graph. When we define a feature, we need to choose, in addition to the number of channels, which transformation rule it obeys.
|
211 |
+
|
212 |
+
The degree or type of a feature is a positive integer that describes how this feature transforms when the input rotates in 3D.
|
213 |
+
|
214 |
+
This is related to [irreducible representations](https://en.wikipedia.org/wiki/Irreducible_representation) of different rotation orders.
|
215 |
+
|
216 |
+
The degree of a feature determines its dimensionality. A type-d feature has a dimensionality of 2d+1.
|
217 |
+
|
218 |
+
Some common examples include:
|
219 |
+
- Degree 0: 1D scalars invariant to rotation
|
220 |
+
- Degree 1: 3D vectors that rotate according to 3D rotation matrices
|
221 |
+
- Degree 2: 5D vectors that rotate according to 5D [Wigner-D matrices](https://en.wikipedia.org/wiki/Wigner_D-matrix). These can represent symmetric traceless 3x3 matrices.
|
222 |
+
|
223 |
+
**Fiber**
|
224 |
+
|
225 |
+
A fiber can be viewed as a representation of a set of features of different types or degrees (positive integers), where each feature type transforms according to its rule.
|
226 |
+
|
227 |
+
In this repository, a fiber can be seen as a dictionary with degrees as keys and numbers of channels as values.
|
228 |
+
|
229 |
+
**Multiplicity**
|
230 |
+
|
231 |
+
The multiplicity of a feature of a given type is the number of channels of this feature.
|
232 |
+
|
233 |
+
**Tensor Field Network**
|
234 |
+
|
235 |
+
A [Tensor Field Network](https://arxiv.org/abs/1802.08219) is a kind of equivariant graph convolution that can combine features of different degrees and produce new ones while preserving equivariance thanks to [tensor products](https://en.wikipedia.org/wiki/Tensor_product).
|
236 |
+
|
237 |
+
**Equivariance**
|
238 |
+
|
239 |
+
[Equivariance](https://en.wikipedia.org/wiki/Equivariant_map) is a property of a function of model stating that applying a symmetry transformation to the input and then computing the function produces the same result as computing the function and then applying the transformation to the output.
|
240 |
+
|
241 |
+
In the case of SE(3)-Transformer, the symmetry group is the group of continuous roto-translations (SE(3)).
|
242 |
+
|
243 |
+
## Setup
|
244 |
+
|
245 |
+
The following section lists the requirements that you need to meet in order to start training the SE(3)-Transformer model.
|
246 |
+
|
247 |
+
### Requirements
|
248 |
+
|
249 |
+
This repository contains a Dockerfile which extends the PyTorch 21.07 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
|
250 |
+
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
|
251 |
+
- PyTorch 21.07+ NGC container
|
252 |
+
- Supported GPUs:
|
253 |
+
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
|
254 |
+
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
|
255 |
+
- [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
|
256 |
+
|
257 |
+
For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
|
258 |
+
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
|
259 |
+
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
|
260 |
+
- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
|
261 |
+
|
262 |
+
For those unable to use the PyTorch NGC container to set up the required environment or create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
|
263 |
+
|
264 |
+
## Quick Start Guide
|
265 |
+
|
266 |
+
To train your model using mixed or TF32 precision with Tensor Cores or FP32, perform the following steps using the default parameters of the SE(3)-Transformer model on the QM9 dataset. For the specifics concerning training and inference, refer to the [Advanced](#advanced) section.
|
267 |
+
|
268 |
+
1. Clone the repository.
|
269 |
+
```
|
270 |
+
git clone https://github.com/NVIDIA/DeepLearningExamples
|
271 |
+
cd DeepLearningExamples/PyTorch/DrugDiscovery/SE3Transformer
|
272 |
+
```
|
273 |
+
|
274 |
+
2. Build the `se3-transformer` PyTorch NGC container.
|
275 |
+
```
|
276 |
+
docker build -t se3-transformer .
|
277 |
+
```
|
278 |
+
|
279 |
+
3. Start an interactive session in the NGC container to run training/inference.
|
280 |
+
```
|
281 |
+
mkdir -p results
|
282 |
+
docker run -it --runtime=nvidia --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 --rm -v ${PWD}/results:/results se3-transformer:latest
|
283 |
+
```
|
284 |
+
|
285 |
+
4. Start training.
|
286 |
+
```
|
287 |
+
bash scripts/train.sh
|
288 |
+
```
|
289 |
+
|
290 |
+
5. Start inference/predictions.
|
291 |
+
```
|
292 |
+
bash scripts/predict.sh
|
293 |
+
```
|
294 |
+
|
295 |
+
|
296 |
+
Now that you have your model trained and evaluated, you can choose to compare your training results with our [Training accuracy results](#training-accuracy-results). You can also choose to benchmark your performance to [Training performance benchmark](#training-performance-results) or [Inference performance benchmark](#inference-performance-results). Following the steps in these sections will ensure that you achieve the same accuracy and performance results as stated in the [Results](#results) section.
|
297 |
+
|
298 |
+
## Advanced
|
299 |
+
|
300 |
+
The following sections provide greater details of the dataset, running training and inference, and the training results.
|
301 |
+
|
302 |
+
### Scripts and sample code
|
303 |
+
|
304 |
+
In the root directory, the most important files are:
|
305 |
+
- `Dockerfile`: container with the basic set of dependencies to run SE(3)-Transformers
|
306 |
+
- `requirements.txt`: set of extra requirements to run SE(3)-Transformers
|
307 |
+
- `se3_transformer/data_loading/qm9.py`: QM9 data loading and preprocessing, as well as bases precomputation
|
308 |
+
- `se3_transformer/model/layers/`: directory containing model architecture layers
|
309 |
+
- `se3_transformer/model/transformer.py`: main Transformer module
|
310 |
+
- `se3_transformer/model/basis.py`: logic for computing bases matrices
|
311 |
+
- `se3_transformer/runtime/training.py`: training script, to be run as a python module
|
312 |
+
- `se3_transformer/runtime/inference.py`: inference script, to be run as a python module
|
313 |
+
- `se3_transformer/runtime/metrics.py`: MAE metric with support for multi-GPU synchronization
|
314 |
+
- `se3_transformer/runtime/loggers.py`: [DLLogger](https://github.com/NVIDIA/dllogger) and [W&B](wandb.ai/) loggers
|
315 |
+
|
316 |
+
|
317 |
+
### Parameters
|
318 |
+
|
319 |
+
The complete list of the available parameters for the `training.py` script contains:
|
320 |
+
|
321 |
+
**General**
|
322 |
+
|
323 |
+
- `--epochs`: Number of training epochs (default: `100` for single-GPU)
|
324 |
+
- `--batch_size`: Batch size (default: `240`)
|
325 |
+
- `--seed`: Set a seed globally (default: `None`)
|
326 |
+
- `--num_workers`: Number of dataloading workers (default: `8`)
|
327 |
+
- `--amp`: Use Automatic Mixed Precision (default `false`)
|
328 |
+
- `--gradient_clip`: Clipping of the gradient norms (default: `None`)
|
329 |
+
- `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
|
330 |
+
- `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
|
331 |
+
- `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
|
332 |
+
- `--silent`: Minimize stdout output (default: `false`)
|
333 |
+
|
334 |
+
**Paths**
|
335 |
+
|
336 |
+
- `--data_dir`: Directory where the data is located or should be downloaded (default: `./data`)
|
337 |
+
- `--log_dir`: Directory where the results logs should be saved (default: `/results`)
|
338 |
+
- `--save_ckpt_path`: File where the checkpoint should be saved (default: `None`)
|
339 |
+
- `--load_ckpt_path`: File of the checkpoint to be loaded (default: `None`)
|
340 |
+
|
341 |
+
**Optimizer**
|
342 |
+
|
343 |
+
- `--optimizer`: Optimizer to use (default: `adam`)
|
344 |
+
- `--learning_rate`: Learning rate to use (default: `0.002` for single-GPU)
|
345 |
+
- `--momentum`: Momentum to use (default: `0.9`)
|
346 |
+
- `--weight_decay`: Weight decay to use (default: `0.1`)
|
347 |
+
|
348 |
+
**QM9 dataset**
|
349 |
+
|
350 |
+
- `--task`: Regression task to train on (default: `homo`)
|
351 |
+
- `--precompute_bases`: Precompute bases at the beginning of the script during dataset initialization, instead of computing them at the beginning of each forward pass (default: `false`)
|
352 |
+
|
353 |
+
**Model architecture**
|
354 |
+
|
355 |
+
- `--num_layers`: Number of stacked Transformer layers (default: `7`)
|
356 |
+
- `--num_heads`: Number of heads in self-attention (default: `8`)
|
357 |
+
- `--channels_div`: Channels division before feeding to attention layer (default: `2`)
|
358 |
+
- `--pooling`: Type of graph pooling (default: `max`)
|
359 |
+
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
|
360 |
+
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
|
361 |
+
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
|
362 |
+
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
|
363 |
+
- `--num_channels`: Number of channels for the hidden features (default: `32`)
|
364 |
+
|
365 |
+
|
366 |
+
### Command-line options
|
367 |
+
|
368 |
+
To show the full list of available options and their descriptions, use the `-h` or `--help` command-line option, for example: `python -m se3_transformer.runtime.training --help`.
|
369 |
+
|
370 |
+
|
371 |
+
### Dataset guidelines
|
372 |
+
|
373 |
+
#### Demo dataset
|
374 |
+
|
375 |
+
The SE(3)-Transformer was trained on the QM9 dataset.
|
376 |
+
|
377 |
+
The QM9 dataset is hosted on DGL servers and downloaded (38MB) automatically when needed. By default, it is stored in the `./data` directory, but this location can be changed with the `--data_dir` argument.
|
378 |
+
|
379 |
+
The dataset is saved as a `qm9_edge.npz` file and converted to DGL graphs at runtime.
|
380 |
+
|
381 |
+
As input features, we use:
|
382 |
+
- Node features (6D):
|
383 |
+
- One-hot-encoded atom type (5D) (atom types: H, C, N, O, F)
|
384 |
+
- Number of protons of each atom (1D)
|
385 |
+
- Edge features: one-hot-encoded bond type (4D) (bond types: single, double, triple, aromatic)
|
386 |
+
- The relative positions between adjacent nodes (atoms)
|
387 |
+
|
388 |
+
#### Custom datasets
|
389 |
+
|
390 |
+
To use this network on a new dataset, you can extend the `DataModule` class present in `se3_transformer/data_loading/data_module.py`.
|
391 |
+
|
392 |
+
Your custom collate function should return a tuple with:
|
393 |
+
|
394 |
+
- A (batched) DGLGraph object
|
395 |
+
- A dictionary of node features ({‘{degree}’: tensor})
|
396 |
+
- A dictionary of edge features ({‘{degree}’: tensor})
|
397 |
+
- (Optional) Precomputed bases as a dictionary
|
398 |
+
- Labels as a tensor
|
399 |
+
|
400 |
+
You can then modify the `training.py` and `inference.py` scripts to use your new data module.
|
401 |
+
|
402 |
+
### Training process
|
403 |
+
|
404 |
+
The training script is `se3_transformer/runtime/training.py`, to be run as a module: `python -m se3_transformer.runtime.training`.
|
405 |
+
|
406 |
+
**Logs**
|
407 |
+
|
408 |
+
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
|
409 |
+
|
410 |
+
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
|
411 |
+
|
412 |
+
**Checkpoints**
|
413 |
+
|
414 |
+
The argument `--save_ckpt_path` can be set to the path of the file where the checkpoints should be saved.
|
415 |
+
`--ckpt_interval` can also be set to the interval (in the number of epochs) between checkpoints.
|
416 |
+
|
417 |
+
**Evaluation**
|
418 |
+
|
419 |
+
The evaluation metric is the Mean Absolute Error (MAE).
|
420 |
+
|
421 |
+
`--eval_interval` can be set to the interval (in the number of epochs) between evaluation rounds. By default, an evaluation round is performed after each epoch.
|
422 |
+
|
423 |
+
**Automatic Mixed Precision**
|
424 |
+
|
425 |
+
To enable Mixed Precision training, add the `--amp` flag.
|
426 |
+
|
427 |
+
**Multi-GPU and multi-node**
|
428 |
+
|
429 |
+
The training script supports the PyTorch elastic launcher to run on multiple GPUs or nodes. Refer to the [official documentation](https://pytorch.org/docs/1.9.0/elastic/run.html).
|
430 |
+
|
431 |
+
For example, to train on all available GPUs with AMP:
|
432 |
+
|
433 |
+
```
|
434 |
+
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --module se3_transformer.runtime.training --amp
|
435 |
+
```
|
436 |
+
|
437 |
+
|
438 |
+
### Inference process
|
439 |
+
|
440 |
+
Inference can be run by using the `se3_transformer.runtime.inference` python module.
|
441 |
+
|
442 |
+
The inference script is `se3_transformer/runtime/inference.py`, to be run as a module: `python -m se3_transformer.runtime.inference`. It requires a pre-trained model checkpoint (to be passed as `--load_ckpt_path`).
|
443 |
+
|
444 |
+
|
445 |
+
## Performance
|
446 |
+
|
447 |
+
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
|
448 |
+
|
449 |
+
### Benchmarking
|
450 |
+
|
451 |
+
The following section shows how to run benchmarks measuring the model performance in training and inference modes.
|
452 |
+
|
453 |
+
#### Training performance benchmark
|
454 |
+
|
455 |
+
To benchmark the training performance on a specific batch size, run `bash scripts/benchmarck_train.sh {BATCH_SIZE}` for single GPU, and `bash scripts/benchmarck_train_multi_gpu.sh {BATCH_SIZE}` for multi-GPU.
|
456 |
+
|
457 |
+
#### Inference performance benchmark
|
458 |
+
|
459 |
+
To benchmark the inference performance on a specific batch size, run `bash scripts/benchmarck_inference.sh {BATCH_SIZE}`.
|
460 |
+
|
461 |
+
### Results
|
462 |
+
|
463 |
+
|
464 |
+
The following sections provide details on how we achieved our performance and accuracy in training and inference.
|
465 |
+
|
466 |
+
#### Training accuracy results
|
467 |
+
|
468 |
+
##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)
|
469 |
+
|
470 |
+
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 (8x A100 80GB) GPUs.
|
471 |
+
|
472 |
+
| GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
|
473 |
+
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
474 |
+
| 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
|
475 |
+
| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
|
476 |
+
|
477 |
+
|
478 |
+
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
|
479 |
+
|
480 |
+
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
|
481 |
+
|
482 |
+
| GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
|
483 |
+
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
484 |
+
| 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
|
485 |
+
| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
|
486 |
+
|
487 |
+
|
488 |
+
#### Training performance results
|
489 |
+
|
490 |
+
##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
|
491 |
+
|
492 |
+
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 8x A100 80GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
|
493 |
+
|
494 |
+
| GPUs | Batch size / GPU | Throughput - TF32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (mixed precision - TF32) | Weak scaling - TF32 | Weak scaling - mixed precision |
|
495 |
+
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
496 |
+
| 1 | 240 | 2.21 | 2.92 | 1.32x | | |
|
497 |
+
| 1 | 120 | 1.81 | 2.04 | 1.13x | | |
|
498 |
+
| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
|
499 |
+
| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
|
500 |
+
|
501 |
+
|
502 |
+
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
503 |
+
|
504 |
+
|
505 |
+
##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
|
506 |
+
|
507 |
+
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 8x V100 16GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
|
508 |
+
|
509 |
+
| GPUs | Batch size / GPU | Throughput - FP32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision |
|
510 |
+
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
511 |
+
| 1 | 240 | 1.25 | 1.88 | 1.50x | | |
|
512 |
+
| 1 | 120 | 1.03 | 1.41 | 1.37x | | |
|
513 |
+
| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
|
514 |
+
| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
|
515 |
+
|
516 |
+
|
517 |
+
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
518 |
+
|
519 |
+
|
520 |
+
#### Inference performance results
|
521 |
+
|
522 |
+
|
523 |
+
##### Inference performance: NVIDIA DGX A100 (1x A100 80GB)
|
524 |
+
|
525 |
+
Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 1x A100 80GB GPU.
|
526 |
+
|
527 |
+
FP16
|
528 |
+
|
529 |
+
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
530 |
+
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
531 |
+
| 1600 | 11.60 | 140.94 | 138.29 | 140.12 | 386.40 |
|
532 |
+
| 800 | 10.74 | 75.69 | 75.74 | 76.50 | 79.77 |
|
533 |
+
| 400 | 8.86 | 45.57 | 46.11 | 46.60 | 49.97 |
|
534 |
+
|
535 |
+
TF32
|
536 |
+
|
537 |
+
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
538 |
+
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
539 |
+
| 1600 | 8.58 | 189.20 | 186.39 | 187.71 | 420.28 |
|
540 |
+
| 800 | 8.28 | 97.56 | 97.20 | 97.73 | 101.13 |
|
541 |
+
| 400 | 7.55 | 53.38 | 53.72 | 54.48 | 56.62 |
|
542 |
+
|
543 |
+
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
544 |
+
|
545 |
+
|
546 |
+
|
547 |
+
##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
|
548 |
+
|
549 |
+
Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU.
|
550 |
+
|
551 |
+
FP16
|
552 |
+
|
553 |
+
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
554 |
+
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
555 |
+
| 1600 | 6.42 | 254.54 | 247.97 | 249.29 | 721.15 |
|
556 |
+
| 800 | 6.13 | 132.07 | 131.90 | 132.70 | 140.15 |
|
557 |
+
| 400 | 5.37 | 75.12 | 76.01 | 76.66 | 79.90 |
|
558 |
+
|
559 |
+
FP32
|
560 |
+
|
561 |
+
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
562 |
+
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
563 |
+
| 1600 | 3.39 | 475.86 | 473.82 | 475.64 | 891.18 |
|
564 |
+
| 800 | 3.36 | 239.17 | 240.64 | 241.65 | 243.70 |
|
565 |
+
| 400 | 3.17 | 126.67 | 128.19 | 128.82 | 130.54 |
|
566 |
+
|
567 |
+
|
568 |
+
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
569 |
+
|
570 |
+
|
571 |
+
## Release notes
|
572 |
+
|
573 |
+
### Changelog
|
574 |
+
|
575 |
+
August 2021
|
576 |
+
- Initial release
|
577 |
+
|
578 |
+
### Known issues
|
579 |
+
|
580 |
+
If you encounter `OSError: [Errno 12] Cannot allocate memory` during the Dataloader iterator creation (more precisely during the `fork()`, this is most likely due to the use of the `--precompute_bases` flag. If you cannot add more RAM or Swap to your machine, it is recommended to turn off bases precomputation by removing the `--precompute_bases` flag or using `--precompute_bases false`.
|
env/SE3Transformer/build/lib/se3_transformer/__init__.py
ADDED
File without changes
|
env/SE3Transformer/build/lib/se3_transformer/data_loading/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .qm9 import QM9DataModule
|
env/SE3Transformer/build/lib/se3_transformer/data_loading/data_module.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import torch.distributed as dist
|
25 |
+
from abc import ABC
|
26 |
+
from torch.utils.data import DataLoader, DistributedSampler, Dataset
|
27 |
+
|
28 |
+
from se3_transformer.runtime.utils import get_local_rank
|
29 |
+
|
30 |
+
|
31 |
+
def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
|
32 |
+
# Classic or distributed dataloader depending on the context
|
33 |
+
sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
|
34 |
+
return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)
|
35 |
+
|
36 |
+
|
37 |
+
class DataModule(ABC):
|
38 |
+
""" Abstract DataModule. Children must define self.ds_{train | val | test}. """
|
39 |
+
|
40 |
+
def __init__(self, **dataloader_kwargs):
|
41 |
+
super().__init__()
|
42 |
+
if get_local_rank() == 0:
|
43 |
+
self.prepare_data()
|
44 |
+
|
45 |
+
# Wait until rank zero has prepared the data (download, preprocessing, ...)
|
46 |
+
if dist.is_initialized():
|
47 |
+
dist.barrier(device_ids=[get_local_rank()])
|
48 |
+
|
49 |
+
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
|
50 |
+
self.ds_train, self.ds_val, self.ds_test = None, None, None
|
51 |
+
|
52 |
+
def prepare_data(self):
|
53 |
+
""" Method called only once per node. Put here any downloading or preprocessing """
|
54 |
+
pass
|
55 |
+
|
56 |
+
def train_dataloader(self) -> DataLoader:
|
57 |
+
return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)
|
58 |
+
|
59 |
+
def val_dataloader(self) -> DataLoader:
|
60 |
+
return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)
|
61 |
+
|
62 |
+
def test_dataloader(self) -> DataLoader:
|
63 |
+
return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)
|
env/SE3Transformer/build/lib/se3_transformer/data_loading/qm9.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
from typing import Tuple
|
24 |
+
|
25 |
+
import dgl
|
26 |
+
import pathlib
|
27 |
+
import torch
|
28 |
+
from dgl.data import QM9EdgeDataset
|
29 |
+
from dgl import DGLGraph
|
30 |
+
from torch import Tensor
|
31 |
+
from torch.utils.data import random_split, DataLoader, Dataset
|
32 |
+
from tqdm import tqdm
|
33 |
+
|
34 |
+
from se3_transformer.data_loading.data_module import DataModule
|
35 |
+
from se3_transformer.model.basis import get_basis
|
36 |
+
from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
|
37 |
+
|
38 |
+
|
39 |
+
def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
|
40 |
+
x = qm9_graph.ndata['pos']
|
41 |
+
src, dst = qm9_graph.edges()
|
42 |
+
rel_pos = x[dst] - x[src]
|
43 |
+
return rel_pos
|
44 |
+
|
45 |
+
|
46 |
+
def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
|
47 |
+
len_full = len(full_dataset)
|
48 |
+
len_train = 100_000
|
49 |
+
len_test = int(0.1 * len_full)
|
50 |
+
len_val = len_full - len_train - len_test
|
51 |
+
return len_train, len_val, len_test
|
52 |
+
|
53 |
+
|
54 |
+
class QM9DataModule(DataModule):
|
55 |
+
"""
|
56 |
+
Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset
|
57 |
+
Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest.
|
58 |
+
This includes all the molecules from QM9 except the ones that are uncharacterized.
|
59 |
+
"""
|
60 |
+
|
61 |
+
NODE_FEATURE_DIM = 6
|
62 |
+
EDGE_FEATURE_DIM = 4
|
63 |
+
|
64 |
+
def __init__(self,
|
65 |
+
data_dir: pathlib.Path,
|
66 |
+
task: str = 'homo',
|
67 |
+
batch_size: int = 240,
|
68 |
+
num_workers: int = 8,
|
69 |
+
num_degrees: int = 4,
|
70 |
+
amp: bool = False,
|
71 |
+
precompute_bases: bool = False,
|
72 |
+
**kwargs):
|
73 |
+
self.data_dir = data_dir # This needs to be before __init__ so that prepare_data has access to it
|
74 |
+
super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
|
75 |
+
self.amp = amp
|
76 |
+
self.task = task
|
77 |
+
self.batch_size = batch_size
|
78 |
+
self.num_degrees = num_degrees
|
79 |
+
|
80 |
+
qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
|
81 |
+
if precompute_bases:
|
82 |
+
bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
|
83 |
+
full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
|
84 |
+
num_workers=num_workers, **qm9_kwargs)
|
85 |
+
else:
|
86 |
+
full_dataset = QM9EdgeDataset(**qm9_kwargs)
|
87 |
+
|
88 |
+
self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset),
|
89 |
+
generator=torch.Generator().manual_seed(0))
|
90 |
+
|
91 |
+
train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]]
|
92 |
+
self.targets_mean = train_targets.mean()
|
93 |
+
self.targets_std = train_targets.std()
|
94 |
+
|
95 |
+
def prepare_data(self):
|
96 |
+
# Download the QM9 preprocessed data
|
97 |
+
QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir))
|
98 |
+
|
99 |
+
def _collate(self, samples):
|
100 |
+
graphs, y, *bases = map(list, zip(*samples))
|
101 |
+
batched_graph = dgl.batch(graphs)
|
102 |
+
edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
|
103 |
+
batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
|
104 |
+
# get node features
|
105 |
+
node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
|
106 |
+
targets = (torch.cat(y) - self.targets_mean) / self.targets_std
|
107 |
+
|
108 |
+
if bases:
|
109 |
+
# collate bases
|
110 |
+
all_bases = {
|
111 |
+
key: torch.cat([b[key] for b in bases[0]], dim=0)
|
112 |
+
for key in bases[0][0].keys()
|
113 |
+
}
|
114 |
+
|
115 |
+
return batched_graph, node_feats, edge_feats, all_bases, targets
|
116 |
+
else:
|
117 |
+
return batched_graph, node_feats, edge_feats, targets
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def add_argparse_args(parent_parser):
|
121 |
+
parser = parent_parser.add_argument_group("QM9 dataset")
|
122 |
+
parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?',
|
123 |
+
choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
124 |
+
'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'],
|
125 |
+
help='Regression task to train on')
|
126 |
+
parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
|
127 |
+
help='Precompute bases at the beginning of the script during dataset initialization,'
|
128 |
+
' instead of computing them at the beginning of each forward pass.')
|
129 |
+
return parent_parser
|
130 |
+
|
131 |
+
def __repr__(self):
|
132 |
+
return f'QM9({self.task})'
|
133 |
+
|
134 |
+
|
135 |
+
class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
|
136 |
+
""" Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
|
137 |
+
|
138 |
+
def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
|
139 |
+
"""
|
140 |
+
:param bases_kwargs: Arguments to feed the bases computation function
|
141 |
+
:param batch_size: Batch size to use when iterating over the dataset for computing bases
|
142 |
+
"""
|
143 |
+
self.bases_kwargs = bases_kwargs
|
144 |
+
self.batch_size = batch_size
|
145 |
+
self.bases = None
|
146 |
+
self.num_workers = num_workers
|
147 |
+
super().__init__(*args, **kwargs)
|
148 |
+
|
149 |
+
def load(self):
|
150 |
+
super().load()
|
151 |
+
# Iterate through the dataset and compute bases (pairwise only)
|
152 |
+
# Potential improvement: use multi-GPU and gather
|
153 |
+
dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
|
154 |
+
collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
|
155 |
+
bases = []
|
156 |
+
for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
|
157 |
+
disable=get_local_rank() != 0):
|
158 |
+
rel_pos = _get_relative_pos(graph)
|
159 |
+
# Compute the bases with the GPU but convert the result to CPU to store in RAM
|
160 |
+
bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
|
161 |
+
self.bases = bases # Assign at the end so that __getitem__ isn't confused
|
162 |
+
|
163 |
+
def __getitem__(self, idx: int):
|
164 |
+
graph, label = super().__getitem__(idx)
|
165 |
+
|
166 |
+
if self.bases:
|
167 |
+
bases_idx = idx // self.batch_size
|
168 |
+
bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size]
|
169 |
+
bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size]
|
170 |
+
return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in
|
171 |
+
self.bases[bases_idx].items()}
|
172 |
+
else:
|
173 |
+
return graph, label
|
env/SE3Transformer/build/lib/se3_transformer/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .transformer import SE3Transformer, SE3TransformerPooled
|
2 |
+
from .fiber import Fiber
|
env/SE3Transformer/build/lib/se3_transformer/model/basis.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
|
25 |
+
from functools import lru_cache
|
26 |
+
from typing import Dict, List
|
27 |
+
|
28 |
+
import e3nn.o3 as o3
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F
|
31 |
+
from torch import Tensor
|
32 |
+
from torch.cuda.nvtx import range as nvtx_range
|
33 |
+
|
34 |
+
from se3_transformer.runtime.utils import degree_to_dim
|
35 |
+
|
36 |
+
|
37 |
+
@lru_cache(maxsize=None)
|
38 |
+
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
|
39 |
+
""" Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """
|
40 |
+
return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)
|
41 |
+
|
42 |
+
|
43 |
+
@lru_cache(maxsize=None)
|
44 |
+
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
|
45 |
+
all_cb = []
|
46 |
+
for d_in in range(max_degree + 1):
|
47 |
+
for d_out in range(max_degree + 1):
|
48 |
+
K_Js = []
|
49 |
+
for J in range(abs(d_in - d_out), d_in + d_out + 1):
|
50 |
+
K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
|
51 |
+
all_cb.append(K_Js)
|
52 |
+
return all_cb
|
53 |
+
|
54 |
+
|
55 |
+
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
|
56 |
+
all_degrees = list(range(2 * max_degree + 1))
|
57 |
+
with nvtx_range('spherical harmonics'):
|
58 |
+
sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
|
59 |
+
return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
|
60 |
+
|
61 |
+
|
62 |
+
@torch.jit.script
|
63 |
+
def get_basis_script(max_degree: int,
|
64 |
+
use_pad_trick: bool,
|
65 |
+
spherical_harmonics: List[Tensor],
|
66 |
+
clebsch_gordon: List[List[Tensor]],
|
67 |
+
amp: bool) -> Dict[str, Tensor]:
|
68 |
+
"""
|
69 |
+
Compute pairwise bases matrices for degrees up to max_degree
|
70 |
+
:param max_degree: Maximum input or output degree
|
71 |
+
:param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
|
72 |
+
:param spherical_harmonics: List of computed spherical harmonics
|
73 |
+
:param clebsch_gordon: List of computed CB-coefficients
|
74 |
+
:param amp: When true, return bases in FP16 precision
|
75 |
+
"""
|
76 |
+
basis = {}
|
77 |
+
idx = 0
|
78 |
+
# Double for loop instead of product() because of JIT script
|
79 |
+
for d_in in range(max_degree + 1):
|
80 |
+
for d_out in range(max_degree + 1):
|
81 |
+
key = f'{d_in},{d_out}'
|
82 |
+
K_Js = []
|
83 |
+
for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
|
84 |
+
Q_J = clebsch_gordon[idx][freq_idx]
|
85 |
+
K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))
|
86 |
+
|
87 |
+
basis[key] = torch.stack(K_Js, 2) # Stack on second dim so order is n l f k
|
88 |
+
if amp:
|
89 |
+
basis[key] = basis[key].half()
|
90 |
+
if use_pad_trick:
|
91 |
+
basis[key] = F.pad(basis[key], (0, 1)) # Pad the k dimension, that can be sliced later
|
92 |
+
|
93 |
+
idx += 1
|
94 |
+
|
95 |
+
return basis
|
96 |
+
|
97 |
+
|
98 |
+
@torch.jit.script
|
99 |
+
def update_basis_with_fused(basis: Dict[str, Tensor],
|
100 |
+
max_degree: int,
|
101 |
+
use_pad_trick: bool,
|
102 |
+
fully_fused: bool) -> Dict[str, Tensor]:
|
103 |
+
""" Update the basis dict with partially and optionally fully fused bases """
|
104 |
+
num_edges = basis['0,0'].shape[0]
|
105 |
+
device = basis['0,0'].device
|
106 |
+
dtype = basis['0,0'].dtype
|
107 |
+
sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])
|
108 |
+
|
109 |
+
# Fused per output degree
|
110 |
+
for d_out in range(max_degree + 1):
|
111 |
+
sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
|
112 |
+
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
|
113 |
+
device=device, dtype=dtype)
|
114 |
+
acc_d, acc_f = 0, 0
|
115 |
+
for d_in in range(max_degree + 1):
|
116 |
+
basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
|
117 |
+
:degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
|
118 |
+
|
119 |
+
acc_d += degree_to_dim(d_in)
|
120 |
+
acc_f += degree_to_dim(min(d_out, d_in))
|
121 |
+
|
122 |
+
basis[f'out{d_out}_fused'] = basis_fused
|
123 |
+
|
124 |
+
# Fused per input degree
|
125 |
+
for d_in in range(max_degree + 1):
|
126 |
+
sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
|
127 |
+
basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
|
128 |
+
device=device, dtype=dtype)
|
129 |
+
acc_d, acc_f = 0, 0
|
130 |
+
for d_out in range(max_degree + 1):
|
131 |
+
basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
|
132 |
+
= basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
|
133 |
+
|
134 |
+
acc_d += degree_to_dim(d_out)
|
135 |
+
acc_f += degree_to_dim(min(d_out, d_in))
|
136 |
+
|
137 |
+
basis[f'in{d_in}_fused'] = basis_fused
|
138 |
+
|
139 |
+
if fully_fused:
|
140 |
+
# Fully fused
|
141 |
+
# Double sum this way because of JIT script
|
142 |
+
sum_freq = sum([
|
143 |
+
sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
|
144 |
+
])
|
145 |
+
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)
|
146 |
+
|
147 |
+
acc_d, acc_f = 0, 0
|
148 |
+
for d_out in range(max_degree + 1):
|
149 |
+
b = basis[f'out{d_out}_fused']
|
150 |
+
basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
|
151 |
+
:degree_to_dim(d_out)]
|
152 |
+
acc_f += b.shape[2]
|
153 |
+
acc_d += degree_to_dim(d_out)
|
154 |
+
|
155 |
+
basis['fully_fused'] = basis_fused
|
156 |
+
|
157 |
+
del basis['0,0'] # We know that the basis for l = k = 0 is filled with a constant
|
158 |
+
return basis
|
159 |
+
|
160 |
+
|
161 |
+
def get_basis(relative_pos: Tensor,
|
162 |
+
max_degree: int = 4,
|
163 |
+
compute_gradients: bool = False,
|
164 |
+
use_pad_trick: bool = False,
|
165 |
+
amp: bool = False) -> Dict[str, Tensor]:
|
166 |
+
with nvtx_range('spherical harmonics'):
|
167 |
+
spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
|
168 |
+
with nvtx_range('CB coefficients'):
|
169 |
+
clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)
|
170 |
+
|
171 |
+
with torch.autograd.set_grad_enabled(compute_gradients):
|
172 |
+
with nvtx_range('bases'):
|
173 |
+
basis = get_basis_script(max_degree=max_degree,
|
174 |
+
use_pad_trick=use_pad_trick,
|
175 |
+
spherical_harmonics=spherical_harmonics,
|
176 |
+
clebsch_gordon=clebsch_gordon,
|
177 |
+
amp=amp)
|
178 |
+
return basis
|
env/SE3Transformer/build/lib/se3_transformer/model/fiber.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
|
25 |
+
from collections import namedtuple
|
26 |
+
from itertools import product
|
27 |
+
from typing import Dict
|
28 |
+
|
29 |
+
import torch
|
30 |
+
from torch import Tensor
|
31 |
+
|
32 |
+
from se3_transformer.runtime.utils import degree_to_dim
|
33 |
+
|
34 |
+
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
|
35 |
+
|
36 |
+
|
37 |
+
class Fiber(dict):
|
38 |
+
"""
|
39 |
+
Describes the structure of some set of features.
|
40 |
+
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
|
41 |
+
Type-0 features: invariant scalars
|
42 |
+
Type-1 features: equivariant 3D vectors
|
43 |
+
Type-2 features: equivariant symmetric traceless matrices
|
44 |
+
...
|
45 |
+
|
46 |
+
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
|
47 |
+
The 'multiplicity' or 'number of channels' is the number of features of a given type.
|
48 |
+
This class puts together all the degrees and their multiplicities in order to describe
|
49 |
+
the inputs, outputs or hidden features of SE3 layers.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, structure):
|
53 |
+
if isinstance(structure, dict):
|
54 |
+
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
|
55 |
+
elif not isinstance(structure[0], FiberEl):
|
56 |
+
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
|
57 |
+
self.structure = structure
|
58 |
+
super().__init__({d: m for d, m in self.structure})
|
59 |
+
|
60 |
+
@property
|
61 |
+
def degrees(self):
|
62 |
+
return sorted([t.degree for t in self.structure])
|
63 |
+
|
64 |
+
@property
|
65 |
+
def channels(self):
|
66 |
+
return [self[d] for d in self.degrees]
|
67 |
+
|
68 |
+
@property
|
69 |
+
def num_features(self):
|
70 |
+
""" Size of the resulting tensor if all features were concatenated together """
|
71 |
+
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
def create(num_degrees: int, num_channels: int):
|
75 |
+
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
|
76 |
+
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
|
77 |
+
|
78 |
+
@staticmethod
|
79 |
+
def from_features(feats: Dict[str, Tensor]):
|
80 |
+
""" Infer the Fiber structure from a feature dict """
|
81 |
+
structure = {}
|
82 |
+
for k, v in feats.items():
|
83 |
+
degree = int(k)
|
84 |
+
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
|
85 |
+
assert v.shape[-1] == degree_to_dim(degree)
|
86 |
+
structure[degree] = v.shape[-2]
|
87 |
+
return Fiber(structure)
|
88 |
+
|
89 |
+
def __getitem__(self, degree: int):
|
90 |
+
""" fiber[degree] returns the multiplicity for this degree """
|
91 |
+
return dict(self.structure).get(degree, 0)
|
92 |
+
|
93 |
+
def __iter__(self):
|
94 |
+
""" Iterate over namedtuples (degree, channels) """
|
95 |
+
return iter(self.structure)
|
96 |
+
|
97 |
+
def __mul__(self, other):
|
98 |
+
"""
|
99 |
+
If other in an int, multiplies all the multiplicities by other.
|
100 |
+
If other is a fiber, returns the cartesian product.
|
101 |
+
"""
|
102 |
+
if isinstance(other, Fiber):
|
103 |
+
return product(self.structure, other.structure)
|
104 |
+
elif isinstance(other, int):
|
105 |
+
return Fiber({t.degree: t.channels * other for t in self.structure})
|
106 |
+
|
107 |
+
def __add__(self, other):
|
108 |
+
"""
|
109 |
+
If other in an int, add other to all the multiplicities.
|
110 |
+
If other is a fiber, add the multiplicities of the fibers together.
|
111 |
+
"""
|
112 |
+
if isinstance(other, Fiber):
|
113 |
+
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
|
114 |
+
elif isinstance(other, int):
|
115 |
+
return Fiber({t.degree: t.channels + other for t in self.structure})
|
116 |
+
|
117 |
+
def __repr__(self):
|
118 |
+
return str(self.structure)
|
119 |
+
|
120 |
+
@staticmethod
|
121 |
+
def combine_max(f1, f2):
|
122 |
+
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
|
123 |
+
new_dict = dict(f1.structure)
|
124 |
+
for k, m in f2.structure:
|
125 |
+
new_dict[k] = max(new_dict.get(k, 0), m)
|
126 |
+
|
127 |
+
return Fiber(list(new_dict.items()))
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def combine_selectively(f1, f2):
|
131 |
+
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
|
132 |
+
# only use orders which occur in fiber f1
|
133 |
+
new_dict = dict(f1.structure)
|
134 |
+
for k in f1.degrees:
|
135 |
+
if k in f2.degrees:
|
136 |
+
new_dict[k] += f2[k]
|
137 |
+
return Fiber(list(new_dict.items()))
|
138 |
+
|
139 |
+
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
|
140 |
+
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
|
141 |
+
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
|
142 |
+
self.degrees]
|
143 |
+
fibers = torch.cat(fibers, -1)
|
144 |
+
return fibers
|
env/SE3Transformer/build/lib/se3_transformer/model/layers/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .linear import LinearSE3
|
2 |
+
from .norm import NormSE3
|
3 |
+
from .pooling import GPooling
|
4 |
+
from .convolution import ConvSE3
|
5 |
+
from .attention import AttentionBlockSE3
|
env/SE3Transformer/build/lib/se3_transformer/model/layers/attention.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import dgl
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
from dgl import DGLGraph
|
29 |
+
from dgl.ops import edge_softmax
|
30 |
+
from torch import Tensor
|
31 |
+
from typing import Dict, Optional, Union
|
32 |
+
|
33 |
+
from se3_transformer.model.fiber import Fiber
|
34 |
+
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
35 |
+
from se3_transformer.model.layers.linear import LinearSE3
|
36 |
+
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
|
37 |
+
from torch.cuda.nvtx import range as nvtx_range
|
38 |
+
|
39 |
+
|
40 |
+
class AttentionSE3(nn.Module):
|
41 |
+
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
num_heads: int,
|
46 |
+
key_fiber: Fiber,
|
47 |
+
value_fiber: Fiber
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
:param num_heads: Number of attention heads
|
51 |
+
:param key_fiber: Fiber for the keys (and also for the queries)
|
52 |
+
:param value_fiber: Fiber for the values
|
53 |
+
"""
|
54 |
+
super().__init__()
|
55 |
+
self.num_heads = num_heads
|
56 |
+
self.key_fiber = key_fiber
|
57 |
+
self.value_fiber = value_fiber
|
58 |
+
|
59 |
+
def forward(
|
60 |
+
self,
|
61 |
+
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
62 |
+
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
63 |
+
query: Dict[str, Tensor], # node features
|
64 |
+
graph: DGLGraph
|
65 |
+
):
|
66 |
+
with nvtx_range('AttentionSE3'):
|
67 |
+
with nvtx_range('reshape keys and queries'):
|
68 |
+
if isinstance(key, Tensor):
|
69 |
+
# case where features of all types are fused
|
70 |
+
key = key.reshape(key.shape[0], self.num_heads, -1)
|
71 |
+
# need to reshape queries that way to keep the same layout as keys
|
72 |
+
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
|
73 |
+
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
|
74 |
+
else:
|
75 |
+
# features are not fused, need to fuse and reshape them
|
76 |
+
key = self.key_fiber.to_attention_heads(key, self.num_heads)
|
77 |
+
query = self.key_fiber.to_attention_heads(query, self.num_heads)
|
78 |
+
|
79 |
+
with nvtx_range('attention dot product + softmax'):
|
80 |
+
# Compute attention weights (softmax of inner product between key and query)
|
81 |
+
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
|
82 |
+
edge_weights /= np.sqrt(self.key_fiber.num_features)
|
83 |
+
edge_weights = edge_softmax(graph, edge_weights)
|
84 |
+
edge_weights = edge_weights[..., None, None]
|
85 |
+
|
86 |
+
with nvtx_range('weighted sum'):
|
87 |
+
if isinstance(value, Tensor):
|
88 |
+
# features of all types are fused
|
89 |
+
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
|
90 |
+
weights = edge_weights * v
|
91 |
+
feat_out = dgl.ops.copy_e_sum(graph, weights)
|
92 |
+
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
|
93 |
+
out = unfuse_features(feat_out, self.value_fiber.degrees)
|
94 |
+
else:
|
95 |
+
out = {}
|
96 |
+
for degree, channels in self.value_fiber:
|
97 |
+
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
|
98 |
+
degree_to_dim(degree))
|
99 |
+
weights = edge_weights * v
|
100 |
+
res = dgl.ops.copy_e_sum(graph, weights)
|
101 |
+
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
|
102 |
+
|
103 |
+
return out
|
104 |
+
|
105 |
+
|
106 |
+
class AttentionBlockSE3(nn.Module):
|
107 |
+
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
fiber_in: Fiber,
|
112 |
+
fiber_out: Fiber,
|
113 |
+
fiber_edge: Optional[Fiber] = None,
|
114 |
+
num_heads: int = 4,
|
115 |
+
channels_div: int = 2,
|
116 |
+
use_layer_norm: bool = False,
|
117 |
+
max_degree: bool = 4,
|
118 |
+
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
119 |
+
**kwargs
|
120 |
+
):
|
121 |
+
"""
|
122 |
+
:param fiber_in: Fiber describing the input features
|
123 |
+
:param fiber_out: Fiber describing the output features
|
124 |
+
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
125 |
+
:param num_heads: Number of attention heads
|
126 |
+
:param channels_div: Divide the channels by this integer for computing values
|
127 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
128 |
+
:param max_degree: Maximum degree used in the bases computation
|
129 |
+
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
130 |
+
"""
|
131 |
+
super().__init__()
|
132 |
+
if fiber_edge is None:
|
133 |
+
fiber_edge = Fiber({})
|
134 |
+
self.fiber_in = fiber_in
|
135 |
+
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
|
136 |
+
value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
|
137 |
+
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
|
138 |
+
# (queries are merely projected, hence degrees have to match input)
|
139 |
+
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
|
140 |
+
|
141 |
+
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
|
142 |
+
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
|
143 |
+
allow_fused_output=True)
|
144 |
+
self.to_query = LinearSE3(fiber_in, key_query_fiber)
|
145 |
+
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
|
146 |
+
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
|
147 |
+
|
148 |
+
def forward(
|
149 |
+
self,
|
150 |
+
node_features: Dict[str, Tensor],
|
151 |
+
edge_features: Dict[str, Tensor],
|
152 |
+
graph: DGLGraph,
|
153 |
+
basis: Dict[str, Tensor]
|
154 |
+
):
|
155 |
+
with nvtx_range('AttentionBlockSE3'):
|
156 |
+
with nvtx_range('keys / values'):
|
157 |
+
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
|
158 |
+
key, value = self._get_key_value_from_fused(fused_key_value)
|
159 |
+
|
160 |
+
with nvtx_range('queries'):
|
161 |
+
query = self.to_query(node_features)
|
162 |
+
|
163 |
+
z = self.attention(value, key, query, graph)
|
164 |
+
z_concat = aggregate_residual(node_features, z, 'cat')
|
165 |
+
return self.project(z_concat)
|
166 |
+
|
167 |
+
def _get_key_value_from_fused(self, fused_key_value):
|
168 |
+
# Extract keys and queries features from fused features
|
169 |
+
if isinstance(fused_key_value, Tensor):
|
170 |
+
# Previous layer was a fully fused convolution
|
171 |
+
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
|
172 |
+
else:
|
173 |
+
key, value = {}, {}
|
174 |
+
for degree, feat in fused_key_value.items():
|
175 |
+
if int(degree) in self.fiber_in.degrees:
|
176 |
+
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
|
177 |
+
else:
|
178 |
+
value[degree] = feat
|
179 |
+
|
180 |
+
return key, value
|
env/SE3Transformer/build/lib/se3_transformer/model/layers/convolution.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from enum import Enum
|
25 |
+
from itertools import product
|
26 |
+
from typing import Dict
|
27 |
+
|
28 |
+
import dgl
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
import torch.nn as nn
|
32 |
+
from dgl import DGLGraph
|
33 |
+
from torch import Tensor
|
34 |
+
from torch.cuda.nvtx import range as nvtx_range
|
35 |
+
|
36 |
+
from se3_transformer.model.fiber import Fiber
|
37 |
+
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
38 |
+
|
39 |
+
|
40 |
+
class ConvSE3FuseLevel(Enum):
|
41 |
+
"""
|
42 |
+
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
|
43 |
+
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
|
44 |
+
A higher level means faster training, but also more memory usage.
|
45 |
+
If you are tight on memory and want to feed large inputs to the network, choose a low value.
|
46 |
+
If you want to train fast, choose a high value.
|
47 |
+
Recommended value is FULL with AMP.
|
48 |
+
|
49 |
+
Fully fused TFN convolutions requirements:
|
50 |
+
- all input channels are the same
|
51 |
+
- all output channels are the same
|
52 |
+
- input degrees span the range [0, ..., max_degree]
|
53 |
+
- output degrees span the range [0, ..., max_degree]
|
54 |
+
|
55 |
+
Partially fused TFN convolutions requirements:
|
56 |
+
* For fusing by output degree:
|
57 |
+
- all input channels are the same
|
58 |
+
- input degrees span the range [0, ..., max_degree]
|
59 |
+
* For fusing by input degree:
|
60 |
+
- all output channels are the same
|
61 |
+
- output degrees span the range [0, ..., max_degree]
|
62 |
+
|
63 |
+
Original TFN pairwise convolutions: no requirements
|
64 |
+
"""
|
65 |
+
|
66 |
+
FULL = 2
|
67 |
+
PARTIAL = 1
|
68 |
+
NONE = 0
|
69 |
+
|
70 |
+
|
71 |
+
class RadialProfile(nn.Module):
|
72 |
+
"""
|
73 |
+
Radial profile function.
|
74 |
+
Outputs weights used to weigh basis matrices in order to get convolution kernels.
|
75 |
+
In TFN notation: $R^{l,k}$
|
76 |
+
In SE(3)-Transformer notation: $\phi^{l,k}$
|
77 |
+
|
78 |
+
Note:
|
79 |
+
In the original papers, this function only depends on relative node distances ||x||.
|
80 |
+
Here, we allow this function to also take as input additional invariant edge features.
|
81 |
+
This does not break equivariance and adds expressive power to the model.
|
82 |
+
|
83 |
+
Diagram:
|
84 |
+
invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
num_freq: int,
|
90 |
+
channels_in: int,
|
91 |
+
channels_out: int,
|
92 |
+
edge_dim: int = 1,
|
93 |
+
mid_dim: int = 32,
|
94 |
+
use_layer_norm: bool = False
|
95 |
+
):
|
96 |
+
"""
|
97 |
+
:param num_freq: Number of frequencies
|
98 |
+
:param channels_in: Number of input channels
|
99 |
+
:param channels_out: Number of output channels
|
100 |
+
:param edge_dim: Number of invariant edge features (input to the radial function)
|
101 |
+
:param mid_dim: Size of the hidden MLP layers
|
102 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
103 |
+
"""
|
104 |
+
super().__init__()
|
105 |
+
modules = [
|
106 |
+
nn.Linear(edge_dim, mid_dim),
|
107 |
+
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
108 |
+
nn.ReLU(),
|
109 |
+
nn.Linear(mid_dim, mid_dim),
|
110 |
+
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
111 |
+
nn.ReLU(),
|
112 |
+
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
|
113 |
+
]
|
114 |
+
|
115 |
+
self.net = nn.Sequential(*[m for m in modules if m is not None])
|
116 |
+
|
117 |
+
def forward(self, features: Tensor) -> Tensor:
|
118 |
+
return self.net(features)
|
119 |
+
|
120 |
+
|
121 |
+
class VersatileConvSE3(nn.Module):
|
122 |
+
"""
|
123 |
+
Building block for TFN convolutions.
|
124 |
+
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(self,
|
128 |
+
freq_sum: int,
|
129 |
+
channels_in: int,
|
130 |
+
channels_out: int,
|
131 |
+
edge_dim: int,
|
132 |
+
use_layer_norm: bool,
|
133 |
+
fuse_level: ConvSE3FuseLevel):
|
134 |
+
super().__init__()
|
135 |
+
self.freq_sum = freq_sum
|
136 |
+
self.channels_out = channels_out
|
137 |
+
self.channels_in = channels_in
|
138 |
+
self.fuse_level = fuse_level
|
139 |
+
self.radial_func = RadialProfile(num_freq=freq_sum,
|
140 |
+
channels_in=channels_in,
|
141 |
+
channels_out=channels_out,
|
142 |
+
edge_dim=edge_dim,
|
143 |
+
use_layer_norm=use_layer_norm)
|
144 |
+
|
145 |
+
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
|
146 |
+
with nvtx_range(f'VersatileConvSE3'):
|
147 |
+
num_edges = features.shape[0]
|
148 |
+
in_dim = features.shape[2]
|
149 |
+
with nvtx_range(f'RadialProfile'):
|
150 |
+
radial_weights = self.radial_func(invariant_edge_feats) \
|
151 |
+
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
|
152 |
+
|
153 |
+
if basis is not None:
|
154 |
+
# This block performs the einsum n i l, n o i f, n l f k -> n o k
|
155 |
+
out_dim = basis.shape[-1]
|
156 |
+
if self.fuse_level != ConvSE3FuseLevel.FULL:
|
157 |
+
out_dim += out_dim % 2 - 1 # Account for padded basis
|
158 |
+
basis_view = basis.view(num_edges, in_dim, -1)
|
159 |
+
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
|
160 |
+
return (radial_weights @ tmp)[:, :, :out_dim]
|
161 |
+
else:
|
162 |
+
# k = l = 0 non-fused case
|
163 |
+
return radial_weights @ features
|
164 |
+
|
165 |
+
|
166 |
+
class ConvSE3(nn.Module):
|
167 |
+
"""
|
168 |
+
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
|
169 |
+
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
|
170 |
+
Features of different degrees interact together to produce output features.
|
171 |
+
|
172 |
+
Note 1:
|
173 |
+
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
|
174 |
+
done, and the returned features will be edge features instead of node features.
|
175 |
+
|
176 |
+
Note 2:
|
177 |
+
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
|
178 |
+
Input edge features are concatenated with input source node features before the kernel is applied.
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
fiber_in: Fiber,
|
184 |
+
fiber_out: Fiber,
|
185 |
+
fiber_edge: Fiber,
|
186 |
+
pool: bool = True,
|
187 |
+
use_layer_norm: bool = False,
|
188 |
+
self_interaction: bool = False,
|
189 |
+
max_degree: int = 4,
|
190 |
+
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
191 |
+
allow_fused_output: bool = False
|
192 |
+
):
|
193 |
+
"""
|
194 |
+
:param fiber_in: Fiber describing the input features
|
195 |
+
:param fiber_out: Fiber describing the output features
|
196 |
+
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
197 |
+
:param pool: If True, compute final node features by averaging incoming edge features
|
198 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
199 |
+
:param self_interaction: Apply self-interaction of nodes
|
200 |
+
:param max_degree: Maximum degree used in the bases computation
|
201 |
+
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
202 |
+
:param allow_fused_output: Allow the module to output a fused representation of features
|
203 |
+
"""
|
204 |
+
super().__init__()
|
205 |
+
self.pool = pool
|
206 |
+
self.fiber_in = fiber_in
|
207 |
+
self.fiber_out = fiber_out
|
208 |
+
self.self_interaction = self_interaction
|
209 |
+
self.max_degree = max_degree
|
210 |
+
self.allow_fused_output = allow_fused_output
|
211 |
+
|
212 |
+
# channels_in: account for the concatenation of edge features
|
213 |
+
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
|
214 |
+
channels_out_set = set([f.channels for f in self.fiber_out])
|
215 |
+
unique_channels_in = (len(channels_in_set) == 1)
|
216 |
+
unique_channels_out = (len(channels_out_set) == 1)
|
217 |
+
degrees_up_to_max = list(range(max_degree + 1))
|
218 |
+
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
|
219 |
+
|
220 |
+
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
|
221 |
+
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
|
222 |
+
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
223 |
+
# Single fused convolution
|
224 |
+
self.used_fuse_level = ConvSE3FuseLevel.FULL
|
225 |
+
|
226 |
+
sum_freq = sum([
|
227 |
+
degree_to_dim(min(d_in, d_out))
|
228 |
+
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
|
229 |
+
])
|
230 |
+
|
231 |
+
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
|
232 |
+
fuse_level=self.used_fuse_level, **common_args)
|
233 |
+
|
234 |
+
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
235 |
+
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
|
236 |
+
# Convolutions fused per output degree
|
237 |
+
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
238 |
+
self.conv_out = nn.ModuleDict()
|
239 |
+
for d_out, c_out in fiber_out:
|
240 |
+
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
|
241 |
+
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
|
242 |
+
fuse_level=self.used_fuse_level, **common_args)
|
243 |
+
|
244 |
+
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
245 |
+
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
246 |
+
# Convolutions fused per input degree
|
247 |
+
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
248 |
+
self.conv_in = nn.ModuleDict()
|
249 |
+
for d_in, c_in in fiber_in:
|
250 |
+
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
|
251 |
+
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
|
252 |
+
fuse_level=ConvSE3FuseLevel.FULL, **common_args)
|
253 |
+
#fuse_level=self.used_fuse_level, **common_args)
|
254 |
+
else:
|
255 |
+
# Use pairwise TFN convolutions
|
256 |
+
self.used_fuse_level = ConvSE3FuseLevel.NONE
|
257 |
+
self.conv = nn.ModuleDict()
|
258 |
+
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
|
259 |
+
dict_key = f'{degree_in},{degree_out}'
|
260 |
+
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
|
261 |
+
sum_freq = degree_to_dim(min(degree_in, degree_out))
|
262 |
+
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
|
263 |
+
fuse_level=self.used_fuse_level, **common_args)
|
264 |
+
|
265 |
+
if self_interaction:
|
266 |
+
self.to_kernel_self = nn.ParameterDict()
|
267 |
+
for degree_out, channels_out in fiber_out:
|
268 |
+
if fiber_in[degree_out]:
|
269 |
+
self.to_kernel_self[str(degree_out)] = nn.Parameter(
|
270 |
+
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
271 |
+
|
272 |
+
def forward(
|
273 |
+
self,
|
274 |
+
node_feats: Dict[str, Tensor],
|
275 |
+
edge_feats: Dict[str, Tensor],
|
276 |
+
graph: DGLGraph,
|
277 |
+
basis: Dict[str, Tensor]
|
278 |
+
):
|
279 |
+
with nvtx_range(f'ConvSE3'):
|
280 |
+
invariant_edge_feats = edge_feats['0'].squeeze(-1)
|
281 |
+
src, dst = graph.edges()
|
282 |
+
out = {}
|
283 |
+
in_features = []
|
284 |
+
|
285 |
+
# Fetch all input features from edge and node features
|
286 |
+
for degree_in in self.fiber_in.degrees:
|
287 |
+
src_node_features = node_feats[str(degree_in)][src]
|
288 |
+
if degree_in > 0 and str(degree_in) in edge_feats:
|
289 |
+
# Handle edge features of any type by concatenating them to node features
|
290 |
+
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
|
291 |
+
in_features.append(src_node_features)
|
292 |
+
|
293 |
+
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
|
294 |
+
in_features_fused = torch.cat(in_features, dim=-1)
|
295 |
+
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
|
296 |
+
|
297 |
+
if not self.allow_fused_output or self.self_interaction or self.pool:
|
298 |
+
out = unfuse_features(out, self.fiber_out.degrees)
|
299 |
+
|
300 |
+
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
|
301 |
+
in_features_fused = torch.cat(in_features, dim=-1)
|
302 |
+
for degree_out in self.fiber_out.degrees:
|
303 |
+
out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
|
304 |
+
basis[f'out{degree_out}_fused'])
|
305 |
+
|
306 |
+
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
|
307 |
+
out = 0
|
308 |
+
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
309 |
+
out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
|
310 |
+
basis[f'in{degree_in}_fused'])
|
311 |
+
if not self.allow_fused_output or self.self_interaction or self.pool:
|
312 |
+
out = unfuse_features(out, self.fiber_out.degrees)
|
313 |
+
else:
|
314 |
+
# Fallback to pairwise TFN convolutions
|
315 |
+
for degree_out in self.fiber_out.degrees:
|
316 |
+
out_feature = 0
|
317 |
+
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
318 |
+
dict_key = f'{degree_in},{degree_out}'
|
319 |
+
out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
|
320 |
+
basis.get(dict_key, None))
|
321 |
+
out[str(degree_out)] = out_feature
|
322 |
+
|
323 |
+
for degree_out in self.fiber_out.degrees:
|
324 |
+
if self.self_interaction and str(degree_out) in self.to_kernel_self:
|
325 |
+
with nvtx_range(f'self interaction'):
|
326 |
+
dst_features = node_feats[str(degree_out)][dst]
|
327 |
+
kernel_self = self.to_kernel_self[str(degree_out)]
|
328 |
+
out[str(degree_out)] += kernel_self @ dst_features
|
329 |
+
|
330 |
+
if self.pool:
|
331 |
+
with nvtx_range(f'pooling'):
|
332 |
+
if isinstance(out, dict):
|
333 |
+
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
|
334 |
+
else:
|
335 |
+
out = dgl.ops.copy_e_sum(graph, out)
|
336 |
+
return out
|
env/SE3Transformer/build/lib/se3_transformer/model/layers/linear.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
|
25 |
+
from typing import Dict
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
import torch.nn as nn
|
30 |
+
from torch import Tensor
|
31 |
+
|
32 |
+
from se3_transformer.model.fiber import Fiber
|
33 |
+
|
34 |
+
|
35 |
+
class LinearSE3(nn.Module):
|
36 |
+
"""
|
37 |
+
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
|
38 |
+
Maps a fiber to a fiber with the same degrees (channels may be different).
|
39 |
+
No interaction between degrees, but interaction between channels.
|
40 |
+
|
41 |
+
type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
|
42 |
+
type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
|
43 |
+
:
|
44 |
+
type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
|
48 |
+
super().__init__()
|
49 |
+
self.weights = nn.ParameterDict({
|
50 |
+
str(degree_out): nn.Parameter(
|
51 |
+
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
52 |
+
for degree_out, channels_out in fiber_out
|
53 |
+
})
|
54 |
+
|
55 |
+
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
56 |
+
return {
|
57 |
+
degree: self.weights[degree] @ features[degree]
|
58 |
+
for degree, weight in self.weights.items()
|
59 |
+
}
|
env/SE3Transformer/build/lib/se3_transformer/model/layers/norm.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
|
25 |
+
from typing import Dict
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
from torch import Tensor
|
30 |
+
from torch.cuda.nvtx import range as nvtx_range
|
31 |
+
|
32 |
+
from se3_transformer.model.fiber import Fiber
|
33 |
+
|
34 |
+
|
35 |
+
class NormSE3(nn.Module):
|
36 |
+
"""
|
37 |
+
Norm-based SE(3)-equivariant nonlinearity.
|
38 |
+
|
39 |
+
┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
|
40 |
+
feature_in ──┤ * ──> feature_out
|
41 |
+
└──> feature_phase ────────────────────────────┘
|
42 |
+
"""
|
43 |
+
|
44 |
+
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
|
45 |
+
|
46 |
+
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
|
47 |
+
super().__init__()
|
48 |
+
self.fiber = fiber
|
49 |
+
self.nonlinearity = nonlinearity
|
50 |
+
|
51 |
+
if len(set(fiber.channels)) == 1:
|
52 |
+
# Fuse all the layer normalizations into a group normalization
|
53 |
+
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
|
54 |
+
else:
|
55 |
+
# Use multiple layer normalizations
|
56 |
+
self.layer_norms = nn.ModuleDict({
|
57 |
+
str(degree): nn.LayerNorm(channels)
|
58 |
+
for degree, channels in fiber
|
59 |
+
})
|
60 |
+
|
61 |
+
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
62 |
+
with nvtx_range('NormSE3'):
|
63 |
+
output = {}
|
64 |
+
if hasattr(self, 'group_norm'):
|
65 |
+
# Compute per-degree norms of features
|
66 |
+
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
67 |
+
for d in self.fiber.degrees]
|
68 |
+
fused_norms = torch.cat(norms, dim=-2)
|
69 |
+
|
70 |
+
# Transform the norms only
|
71 |
+
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
|
72 |
+
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
|
73 |
+
|
74 |
+
# Scale features to the new norms
|
75 |
+
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
|
76 |
+
output[str(d)] = features[str(d)] / norm * new_norm
|
77 |
+
else:
|
78 |
+
for degree, feat in features.items():
|
79 |
+
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
80 |
+
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
|
81 |
+
output[degree] = new_norm * feat / norm
|
82 |
+
|
83 |
+
return output
|
env/SE3Transformer/build/lib/se3_transformer/model/layers/pooling.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from typing import Dict, Literal
|
25 |
+
|
26 |
+
import torch.nn as nn
|
27 |
+
from dgl import DGLGraph
|
28 |
+
from dgl.nn.pytorch import AvgPooling, MaxPooling
|
29 |
+
from torch import Tensor
|
30 |
+
|
31 |
+
|
32 |
+
class GPooling(nn.Module):
|
33 |
+
"""
|
34 |
+
Graph max/average pooling on a given feature type.
|
35 |
+
The average can be taken for any feature type, and equivariance will be maintained.
|
36 |
+
The maximum can only be taken for invariant features (type 0).
|
37 |
+
If you want max-pooling for type > 0 features, look into Vector Neurons.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
|
41 |
+
"""
|
42 |
+
:param feat_type: Feature type to pool
|
43 |
+
:param pool: Type of pooling: max or avg
|
44 |
+
"""
|
45 |
+
super().__init__()
|
46 |
+
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
|
47 |
+
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
|
48 |
+
self.feat_type = feat_type
|
49 |
+
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
|
50 |
+
|
51 |
+
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
|
52 |
+
pooled = self.pool(graph, features[str(self.feat_type)])
|
53 |
+
return pooled.squeeze(dim=-1)
|
env/SE3Transformer/build/lib/se3_transformer/model/transformer.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import logging
|
25 |
+
from typing import Optional, Literal, Dict
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
from dgl import DGLGraph
|
30 |
+
from torch import Tensor
|
31 |
+
|
32 |
+
from se3_transformer.model.basis import get_basis, update_basis_with_fused
|
33 |
+
from se3_transformer.model.layers.attention import AttentionBlockSE3
|
34 |
+
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
35 |
+
from se3_transformer.model.layers.norm import NormSE3
|
36 |
+
from se3_transformer.model.layers.pooling import GPooling
|
37 |
+
from se3_transformer.runtime.utils import str2bool
|
38 |
+
from se3_transformer.model.fiber import Fiber
|
39 |
+
|
40 |
+
|
41 |
+
class Sequential(nn.Sequential):
|
42 |
+
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
|
43 |
+
|
44 |
+
def forward(self, input, *args, **kwargs):
|
45 |
+
for module in self:
|
46 |
+
input = module(input, *args, **kwargs)
|
47 |
+
return input
|
48 |
+
|
49 |
+
|
50 |
+
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
|
51 |
+
""" Add relative positions to existing edge features """
|
52 |
+
edge_features = edge_features.copy() if edge_features else {}
|
53 |
+
r = relative_pos.norm(dim=-1, keepdim=True)
|
54 |
+
if '0' in edge_features:
|
55 |
+
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
|
56 |
+
else:
|
57 |
+
edge_features['0'] = r[..., None]
|
58 |
+
|
59 |
+
return edge_features
|
60 |
+
|
61 |
+
|
62 |
+
class SE3Transformer(nn.Module):
|
63 |
+
def __init__(self,
|
64 |
+
num_layers: int,
|
65 |
+
fiber_in: Fiber,
|
66 |
+
fiber_hidden: Fiber,
|
67 |
+
fiber_out: Fiber,
|
68 |
+
num_heads: int,
|
69 |
+
channels_div: int,
|
70 |
+
fiber_edge: Fiber = Fiber({}),
|
71 |
+
return_type: Optional[int] = None,
|
72 |
+
pooling: Optional[Literal['avg', 'max']] = None,
|
73 |
+
norm: bool = True,
|
74 |
+
use_layer_norm: bool = True,
|
75 |
+
tensor_cores: bool = False,
|
76 |
+
low_memory: bool = False,
|
77 |
+
**kwargs):
|
78 |
+
"""
|
79 |
+
:param num_layers: Number of attention layers
|
80 |
+
:param fiber_in: Input fiber description
|
81 |
+
:param fiber_hidden: Hidden fiber description
|
82 |
+
:param fiber_out: Output fiber description
|
83 |
+
:param fiber_edge: Input edge fiber description
|
84 |
+
:param num_heads: Number of attention heads
|
85 |
+
:param channels_div: Channels division before feeding to attention layer
|
86 |
+
:param return_type: Return only features of this type
|
87 |
+
:param pooling: 'avg' or 'max' graph pooling before MLP layers
|
88 |
+
:param norm: Apply a normalization layer after each attention block
|
89 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
90 |
+
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
|
91 |
+
:param low_memory: If True, will use slower ops that use less memory
|
92 |
+
"""
|
93 |
+
super().__init__()
|
94 |
+
self.num_layers = num_layers
|
95 |
+
self.fiber_edge = fiber_edge
|
96 |
+
self.num_heads = num_heads
|
97 |
+
self.channels_div = channels_div
|
98 |
+
self.return_type = return_type
|
99 |
+
self.pooling = pooling
|
100 |
+
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
|
101 |
+
self.tensor_cores = tensor_cores
|
102 |
+
self.low_memory = low_memory
|
103 |
+
|
104 |
+
if low_memory and not tensor_cores:
|
105 |
+
logging.warning('Low memory mode will have no effect with no Tensor Cores')
|
106 |
+
|
107 |
+
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
|
108 |
+
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
|
109 |
+
|
110 |
+
graph_modules = []
|
111 |
+
for i in range(num_layers):
|
112 |
+
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
|
113 |
+
fiber_out=fiber_hidden,
|
114 |
+
fiber_edge=fiber_edge,
|
115 |
+
num_heads=num_heads,
|
116 |
+
channels_div=channels_div,
|
117 |
+
use_layer_norm=use_layer_norm,
|
118 |
+
max_degree=self.max_degree,
|
119 |
+
fuse_level=fuse_level))
|
120 |
+
if norm:
|
121 |
+
graph_modules.append(NormSE3(fiber_hidden))
|
122 |
+
fiber_in = fiber_hidden
|
123 |
+
|
124 |
+
graph_modules.append(ConvSE3(fiber_in=fiber_in,
|
125 |
+
fiber_out=fiber_out,
|
126 |
+
fiber_edge=fiber_edge,
|
127 |
+
self_interaction=True,
|
128 |
+
use_layer_norm=use_layer_norm,
|
129 |
+
max_degree=self.max_degree))
|
130 |
+
self.graph_modules = Sequential(*graph_modules)
|
131 |
+
|
132 |
+
if pooling is not None:
|
133 |
+
assert return_type is not None, 'return_type must be specified when pooling'
|
134 |
+
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
|
135 |
+
|
136 |
+
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
|
137 |
+
edge_feats: Optional[Dict[str, Tensor]] = None,
|
138 |
+
basis: Optional[Dict[str, Tensor]] = None):
|
139 |
+
# Compute bases in case they weren't precomputed as part of the data loading
|
140 |
+
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
|
141 |
+
use_pad_trick=self.tensor_cores and not self.low_memory,
|
142 |
+
amp=torch.is_autocast_enabled())
|
143 |
+
|
144 |
+
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
|
145 |
+
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
|
146 |
+
fully_fused=self.tensor_cores and not self.low_memory)
|
147 |
+
|
148 |
+
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
|
149 |
+
|
150 |
+
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
|
151 |
+
|
152 |
+
if self.pooling is not None:
|
153 |
+
return self.pooling_module(node_feats, graph=graph)
|
154 |
+
|
155 |
+
if self.return_type is not None:
|
156 |
+
return node_feats[str(self.return_type)]
|
157 |
+
|
158 |
+
return node_feats
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def add_argparse_args(parser):
|
162 |
+
parser.add_argument('--num_layers', type=int, default=7,
|
163 |
+
help='Number of stacked Transformer layers')
|
164 |
+
parser.add_argument('--num_heads', type=int, default=8,
|
165 |
+
help='Number of heads in self-attention')
|
166 |
+
parser.add_argument('--channels_div', type=int, default=2,
|
167 |
+
help='Channels division before feeding to attention layer')
|
168 |
+
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
|
169 |
+
help='Type of graph pooling')
|
170 |
+
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
|
171 |
+
help='Apply a normalization layer after each attention block')
|
172 |
+
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
|
173 |
+
help='Apply layer normalization between MLP layers')
|
174 |
+
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
|
175 |
+
help='If true, will use fused ops that are slower but that use less memory '
|
176 |
+
'(expect 25 percent less memory). '
|
177 |
+
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
|
178 |
+
|
179 |
+
return parser
|
180 |
+
|
181 |
+
|
182 |
+
class SE3TransformerPooled(nn.Module):
|
183 |
+
def __init__(self,
|
184 |
+
fiber_in: Fiber,
|
185 |
+
fiber_out: Fiber,
|
186 |
+
fiber_edge: Fiber,
|
187 |
+
num_degrees: int,
|
188 |
+
num_channels: int,
|
189 |
+
output_dim: int,
|
190 |
+
**kwargs):
|
191 |
+
super().__init__()
|
192 |
+
kwargs['pooling'] = kwargs['pooling'] or 'max'
|
193 |
+
self.transformer = SE3Transformer(
|
194 |
+
fiber_in=fiber_in,
|
195 |
+
fiber_hidden=Fiber.create(num_degrees, num_channels),
|
196 |
+
fiber_out=fiber_out,
|
197 |
+
fiber_edge=fiber_edge,
|
198 |
+
return_type=0,
|
199 |
+
**kwargs
|
200 |
+
)
|
201 |
+
|
202 |
+
n_out_features = fiber_out.num_features
|
203 |
+
self.mlp = nn.Sequential(
|
204 |
+
nn.Linear(n_out_features, n_out_features),
|
205 |
+
nn.ReLU(),
|
206 |
+
nn.Linear(n_out_features, output_dim)
|
207 |
+
)
|
208 |
+
|
209 |
+
def forward(self, graph, node_feats, edge_feats, basis=None):
|
210 |
+
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
|
211 |
+
y = self.mlp(feats).squeeze(-1)
|
212 |
+
return y
|
213 |
+
|
214 |
+
@staticmethod
|
215 |
+
def add_argparse_args(parent_parser):
|
216 |
+
parser = parent_parser.add_argument_group("Model architecture")
|
217 |
+
SE3Transformer.add_argparse_args(parser)
|
218 |
+
parser.add_argument('--num_degrees',
|
219 |
+
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
|
220 |
+
type=int, default=4)
|
221 |
+
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
|
222 |
+
return parent_parser
|
env/SE3Transformer/build/lib/se3_transformer/runtime/__init__.py
ADDED
File without changes
|
env/SE3Transformer/build/lib/se3_transformer/runtime/arguments.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import argparse
|
25 |
+
import pathlib
|
26 |
+
|
27 |
+
from se3_transformer.data_loading import QM9DataModule
|
28 |
+
from se3_transformer.model import SE3TransformerPooled
|
29 |
+
from se3_transformer.runtime.utils import str2bool
|
30 |
+
|
31 |
+
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
|
32 |
+
|
33 |
+
paths = PARSER.add_argument_group('Paths')
|
34 |
+
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
|
35 |
+
help='Directory where the data is located or should be downloaded')
|
36 |
+
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
|
37 |
+
help='Directory where the results logs should be saved')
|
38 |
+
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
|
39 |
+
help='Name for the resulting DLLogger JSON file')
|
40 |
+
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
|
41 |
+
help='File where the checkpoint should be saved')
|
42 |
+
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
|
43 |
+
help='File of the checkpoint to be loaded')
|
44 |
+
|
45 |
+
optimizer = PARSER.add_argument_group('Optimizer')
|
46 |
+
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
|
47 |
+
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
|
48 |
+
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
|
49 |
+
optimizer.add_argument('--momentum', type=float, default=0.9)
|
50 |
+
optimizer.add_argument('--weight_decay', type=float, default=0.1)
|
51 |
+
|
52 |
+
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
|
53 |
+
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
|
54 |
+
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
|
55 |
+
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
|
56 |
+
|
57 |
+
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
|
58 |
+
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
|
59 |
+
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
|
60 |
+
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
|
61 |
+
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
|
62 |
+
help='Do an evaluation round every N epochs')
|
63 |
+
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
|
64 |
+
help='Minimize stdout output')
|
65 |
+
|
66 |
+
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
|
67 |
+
help='Benchmark mode')
|
68 |
+
|
69 |
+
QM9DataModule.add_argparse_args(PARSER)
|
70 |
+
SE3TransformerPooled.add_argparse_args(PARSER)
|
env/SE3Transformer/build/lib/se3_transformer/runtime/callbacks.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import time
|
26 |
+
from abc import ABC, abstractmethod
|
27 |
+
from typing import Optional
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
|
32 |
+
from se3_transformer.runtime.loggers import Logger
|
33 |
+
from se3_transformer.runtime.metrics import MeanAbsoluteError
|
34 |
+
|
35 |
+
|
36 |
+
class BaseCallback(ABC):
|
37 |
+
def on_fit_start(self, optimizer, args):
|
38 |
+
pass
|
39 |
+
|
40 |
+
def on_fit_end(self):
|
41 |
+
pass
|
42 |
+
|
43 |
+
def on_epoch_end(self):
|
44 |
+
pass
|
45 |
+
|
46 |
+
def on_batch_start(self):
|
47 |
+
pass
|
48 |
+
|
49 |
+
def on_validation_step(self, input, target, pred):
|
50 |
+
pass
|
51 |
+
|
52 |
+
def on_validation_end(self, epoch=None):
|
53 |
+
pass
|
54 |
+
|
55 |
+
def on_checkpoint_load(self, checkpoint):
|
56 |
+
pass
|
57 |
+
|
58 |
+
def on_checkpoint_save(self, checkpoint):
|
59 |
+
pass
|
60 |
+
|
61 |
+
|
62 |
+
class LRSchedulerCallback(BaseCallback):
|
63 |
+
def __init__(self, logger: Optional[Logger] = None):
|
64 |
+
self.logger = logger
|
65 |
+
self.scheduler = None
|
66 |
+
|
67 |
+
@abstractmethod
|
68 |
+
def get_scheduler(self, optimizer, args):
|
69 |
+
pass
|
70 |
+
|
71 |
+
def on_fit_start(self, optimizer, args):
|
72 |
+
self.scheduler = self.get_scheduler(optimizer, args)
|
73 |
+
|
74 |
+
def on_checkpoint_load(self, checkpoint):
|
75 |
+
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
76 |
+
|
77 |
+
def on_checkpoint_save(self, checkpoint):
|
78 |
+
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
|
79 |
+
|
80 |
+
def on_epoch_end(self):
|
81 |
+
if self.logger is not None:
|
82 |
+
self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
|
83 |
+
self.scheduler.step()
|
84 |
+
|
85 |
+
|
86 |
+
class QM9MetricCallback(BaseCallback):
|
87 |
+
""" Logs the rescaled mean absolute error for QM9 regression tasks """
|
88 |
+
|
89 |
+
def __init__(self, logger, targets_std, prefix=''):
|
90 |
+
self.mae = MeanAbsoluteError()
|
91 |
+
self.logger = logger
|
92 |
+
self.targets_std = targets_std
|
93 |
+
self.prefix = prefix
|
94 |
+
self.best_mae = float('inf')
|
95 |
+
|
96 |
+
def on_validation_step(self, input, target, pred):
|
97 |
+
self.mae(pred.detach(), target.detach())
|
98 |
+
|
99 |
+
def on_validation_end(self, epoch=None):
|
100 |
+
mae = self.mae.compute() * self.targets_std
|
101 |
+
logging.info(f'{self.prefix} MAE: {mae}')
|
102 |
+
self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
|
103 |
+
self.best_mae = min(self.best_mae, mae)
|
104 |
+
|
105 |
+
def on_fit_end(self):
|
106 |
+
if self.best_mae != float('inf'):
|
107 |
+
self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})
|
108 |
+
|
109 |
+
|
110 |
+
class QM9LRSchedulerCallback(LRSchedulerCallback):
|
111 |
+
def __init__(self, logger, epochs):
|
112 |
+
super().__init__(logger)
|
113 |
+
self.epochs = epochs
|
114 |
+
|
115 |
+
def get_scheduler(self, optimizer, args):
|
116 |
+
min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
|
117 |
+
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)
|
118 |
+
|
119 |
+
|
120 |
+
class PerformanceCallback(BaseCallback):
|
121 |
+
def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
|
122 |
+
self.batch_size = batch_size
|
123 |
+
self.warmup_epochs = warmup_epochs
|
124 |
+
self.epoch = 0
|
125 |
+
self.timestamps = []
|
126 |
+
self.mode = mode
|
127 |
+
self.logger = logger
|
128 |
+
|
129 |
+
def on_batch_start(self):
|
130 |
+
if self.epoch >= self.warmup_epochs:
|
131 |
+
self.timestamps.append(time.time() * 1000.0)
|
132 |
+
|
133 |
+
def _log_perf(self):
|
134 |
+
stats = self.process_performance_stats()
|
135 |
+
for k, v in stats.items():
|
136 |
+
logging.info(f'performance {k}: {v}')
|
137 |
+
|
138 |
+
self.logger.log_metrics(stats)
|
139 |
+
|
140 |
+
def on_epoch_end(self):
|
141 |
+
self.epoch += 1
|
142 |
+
|
143 |
+
def on_fit_end(self):
|
144 |
+
if self.epoch > self.warmup_epochs:
|
145 |
+
self._log_perf()
|
146 |
+
self.timestamps = []
|
147 |
+
|
148 |
+
def process_performance_stats(self):
|
149 |
+
timestamps = np.asarray(self.timestamps)
|
150 |
+
deltas = np.diff(timestamps)
|
151 |
+
throughput = (self.batch_size / deltas).mean()
|
152 |
+
stats = {
|
153 |
+
f"throughput_{self.mode}": throughput,
|
154 |
+
f"latency_{self.mode}_mean": deltas.mean(),
|
155 |
+
f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
|
156 |
+
}
|
157 |
+
for level in [90, 95, 99]:
|
158 |
+
stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})
|
159 |
+
|
160 |
+
return stats
|
env/SE3Transformer/build/lib/se3_transformer/runtime/gpu_affinity.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import collections
|
25 |
+
import itertools
|
26 |
+
import math
|
27 |
+
import os
|
28 |
+
import pathlib
|
29 |
+
import re
|
30 |
+
|
31 |
+
import pynvml
|
32 |
+
|
33 |
+
|
34 |
+
class Device:
|
35 |
+
# assumes nvml returns list of 64 bit ints
|
36 |
+
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
|
37 |
+
|
38 |
+
def __init__(self, device_idx):
|
39 |
+
super().__init__()
|
40 |
+
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
|
41 |
+
|
42 |
+
def get_name(self):
|
43 |
+
return pynvml.nvmlDeviceGetName(self.handle)
|
44 |
+
|
45 |
+
def get_uuid(self):
|
46 |
+
return pynvml.nvmlDeviceGetUUID(self.handle)
|
47 |
+
|
48 |
+
def get_cpu_affinity(self):
|
49 |
+
affinity_string = ""
|
50 |
+
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
|
51 |
+
# assume nvml returns list of 64 bit ints
|
52 |
+
affinity_string = "{:064b}".format(j) + affinity_string
|
53 |
+
|
54 |
+
affinity_list = [int(x) for x in affinity_string]
|
55 |
+
affinity_list.reverse() # so core 0 is in 0th element of list
|
56 |
+
|
57 |
+
ret = [i for i, e in enumerate(affinity_list) if e != 0]
|
58 |
+
return ret
|
59 |
+
|
60 |
+
|
61 |
+
def get_thread_siblings_list():
|
62 |
+
"""
|
63 |
+
Returns a list of 2-element integer tuples representing pairs of
|
64 |
+
hyperthreading cores.
|
65 |
+
"""
|
66 |
+
path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
|
67 |
+
thread_siblings_list = []
|
68 |
+
pattern = re.compile(r"(\d+)\D(\d+)")
|
69 |
+
for fname in pathlib.Path(path[0]).glob(path[1:]):
|
70 |
+
with open(fname) as f:
|
71 |
+
content = f.read().strip()
|
72 |
+
res = pattern.findall(content)
|
73 |
+
if res:
|
74 |
+
pair = tuple(map(int, res[0]))
|
75 |
+
thread_siblings_list.append(pair)
|
76 |
+
return thread_siblings_list
|
77 |
+
|
78 |
+
|
79 |
+
def check_socket_affinities(socket_affinities):
|
80 |
+
# sets of cores should be either identical or disjoint
|
81 |
+
for i, j in itertools.product(socket_affinities, socket_affinities):
|
82 |
+
if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
|
83 |
+
raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")
|
84 |
+
|
85 |
+
|
86 |
+
def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
|
87 |
+
devices = [Device(i) for i in range(nproc_per_node)]
|
88 |
+
socket_affinities = [dev.get_cpu_affinity() for dev in devices]
|
89 |
+
|
90 |
+
if exclude_unavailable_cores:
|
91 |
+
available_cores = os.sched_getaffinity(0)
|
92 |
+
socket_affinities = [list(set(affinity) & available_cores) for affinity in socket_affinities]
|
93 |
+
|
94 |
+
check_socket_affinities(socket_affinities)
|
95 |
+
|
96 |
+
return socket_affinities
|
97 |
+
|
98 |
+
|
99 |
+
def set_socket_affinity(gpu_id):
|
100 |
+
"""
|
101 |
+
The process is assigned with all available logical CPU cores from the CPU
|
102 |
+
socket connected to the GPU with a given id.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
gpu_id: index of a GPU
|
106 |
+
"""
|
107 |
+
dev = Device(gpu_id)
|
108 |
+
affinity = dev.get_cpu_affinity()
|
109 |
+
os.sched_setaffinity(0, affinity)
|
110 |
+
|
111 |
+
|
112 |
+
def set_single_affinity(gpu_id):
|
113 |
+
"""
|
114 |
+
The process is assigned with the first available logical CPU core from the
|
115 |
+
list of all CPU cores from the CPU socket connected to the GPU with a given
|
116 |
+
id.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
gpu_id: index of a GPU
|
120 |
+
"""
|
121 |
+
dev = Device(gpu_id)
|
122 |
+
affinity = dev.get_cpu_affinity()
|
123 |
+
|
124 |
+
# exclude unavailable cores
|
125 |
+
available_cores = os.sched_getaffinity(0)
|
126 |
+
affinity = list(set(affinity) & available_cores)
|
127 |
+
os.sched_setaffinity(0, affinity[:1])
|
128 |
+
|
129 |
+
|
130 |
+
def set_single_unique_affinity(gpu_id, nproc_per_node):
|
131 |
+
"""
|
132 |
+
The process is assigned with a single unique available physical CPU core
|
133 |
+
from the list of all CPU cores from the CPU socket connected to the GPU with
|
134 |
+
a given id.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
gpu_id: index of a GPU
|
138 |
+
"""
|
139 |
+
socket_affinities = get_socket_affinities(nproc_per_node)
|
140 |
+
|
141 |
+
siblings_list = get_thread_siblings_list()
|
142 |
+
siblings_dict = dict(siblings_list)
|
143 |
+
|
144 |
+
# remove siblings
|
145 |
+
for idx, socket_affinity in enumerate(socket_affinities):
|
146 |
+
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
|
147 |
+
|
148 |
+
affinities = []
|
149 |
+
assigned = []
|
150 |
+
|
151 |
+
for socket_affinity in socket_affinities:
|
152 |
+
for core in socket_affinity:
|
153 |
+
if core not in assigned:
|
154 |
+
affinities.append([core])
|
155 |
+
assigned.append(core)
|
156 |
+
break
|
157 |
+
os.sched_setaffinity(0, affinities[gpu_id])
|
158 |
+
|
159 |
+
|
160 |
+
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
|
161 |
+
"""
|
162 |
+
The process is assigned with an unique subset of available physical CPU
|
163 |
+
cores from the CPU socket connected to a GPU with a given id.
|
164 |
+
Assignment automatically includes hyperthreading siblings (if siblings are
|
165 |
+
available).
|
166 |
+
|
167 |
+
Args:
|
168 |
+
gpu_id: index of a GPU
|
169 |
+
nproc_per_node: total number of processes per node
|
170 |
+
mode: mode
|
171 |
+
balanced: assign an equal number of physical cores to each process
|
172 |
+
"""
|
173 |
+
socket_affinities = get_socket_affinities(nproc_per_node)
|
174 |
+
|
175 |
+
siblings_list = get_thread_siblings_list()
|
176 |
+
siblings_dict = dict(siblings_list)
|
177 |
+
|
178 |
+
# remove hyperthreading siblings
|
179 |
+
for idx, socket_affinity in enumerate(socket_affinities):
|
180 |
+
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
|
181 |
+
|
182 |
+
socket_affinities_to_device_ids = collections.defaultdict(list)
|
183 |
+
|
184 |
+
for idx, socket_affinity in enumerate(socket_affinities):
|
185 |
+
socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
|
186 |
+
|
187 |
+
# compute minimal number of physical cores per GPU across all GPUs and
|
188 |
+
# sockets, code assigns this number of cores per GPU if balanced == True
|
189 |
+
min_physical_cores_per_gpu = min(
|
190 |
+
[len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()]
|
191 |
+
)
|
192 |
+
|
193 |
+
for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
|
194 |
+
devices_per_group = len(device_ids)
|
195 |
+
if balanced:
|
196 |
+
cores_per_device = min_physical_cores_per_gpu
|
197 |
+
socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
|
198 |
+
else:
|
199 |
+
cores_per_device = len(socket_affinity) // devices_per_group
|
200 |
+
|
201 |
+
for group_id, device_id in enumerate(device_ids):
|
202 |
+
if device_id == gpu_id:
|
203 |
+
|
204 |
+
# In theory there should be no difference in performance between
|
205 |
+
# 'interleaved' and 'continuous' pattern on Intel-based DGX-1,
|
206 |
+
# but 'continuous' should be better for DGX A100 because on AMD
|
207 |
+
# Rome 4 consecutive cores are sharing L3 cache.
|
208 |
+
# TODO: code doesn't attempt to automatically detect layout of
|
209 |
+
# L3 cache, also external environment may already exclude some
|
210 |
+
# cores, this code makes no attempt to detect it and to align
|
211 |
+
# mapping to multiples of 4.
|
212 |
+
|
213 |
+
if mode == "interleaved":
|
214 |
+
affinity = list(socket_affinity[group_id::devices_per_group])
|
215 |
+
elif mode == "continuous":
|
216 |
+
affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
|
217 |
+
else:
|
218 |
+
raise RuntimeError("Unknown set_socket_unique_affinity mode")
|
219 |
+
|
220 |
+
# unconditionally reintroduce hyperthreading siblings, this step
|
221 |
+
# may result in a different numbers of logical cores assigned to
|
222 |
+
# each GPU even if balanced == True (if hyperthreading siblings
|
223 |
+
# aren't available for a subset of cores due to some external
|
224 |
+
# constraints, siblings are re-added unconditionally, in the
|
225 |
+
# worst case unavailable logical core will be ignored by
|
226 |
+
# os.sched_setaffinity().
|
227 |
+
affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
|
228 |
+
os.sched_setaffinity(0, affinity)
|
229 |
+
|
230 |
+
|
231 |
+
def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
|
232 |
+
"""
|
233 |
+
The process is assigned with a proper CPU affinity which matches hardware
|
234 |
+
architecture on a given platform. Usually it improves and stabilizes
|
235 |
+
performance of deep learning training workloads.
|
236 |
+
|
237 |
+
This function assumes that the workload is running in multi-process
|
238 |
+
single-device mode (there are multiple training processes and each process
|
239 |
+
is running on a single GPU), which is typical for multi-GPU training
|
240 |
+
workloads using `torch.nn.parallel.DistributedDataParallel`.
|
241 |
+
|
242 |
+
Available affinity modes:
|
243 |
+
* 'socket' - the process is assigned with all available logical CPU cores
|
244 |
+
from the CPU socket connected to the GPU with a given id.
|
245 |
+
* 'single' - the process is assigned with the first available logical CPU
|
246 |
+
core from the list of all CPU cores from the CPU socket connected to the GPU
|
247 |
+
with a given id (multiple GPUs could be assigned with the same CPU core).
|
248 |
+
* 'single_unique' - the process is assigned with a single unique available
|
249 |
+
physical CPU core from the list of all CPU cores from the CPU socket
|
250 |
+
connected to the GPU with a given id.
|
251 |
+
* 'socket_unique_interleaved' - the process is assigned with an unique
|
252 |
+
subset of available physical CPU cores from the CPU socket connected to a
|
253 |
+
GPU with a given id, hyperthreading siblings are included automatically,
|
254 |
+
cores are assigned with interleaved indexing pattern
|
255 |
+
* 'socket_unique_continuous' - (the default) the process is assigned with an
|
256 |
+
unique subset of available physical CPU cores from the CPU socket connected
|
257 |
+
to a GPU with a given id, hyperthreading siblings are included
|
258 |
+
automatically, cores are assigned with continuous indexing pattern
|
259 |
+
|
260 |
+
'socket_unique_continuous' is the recommended mode for deep learning
|
261 |
+
training workloads on NVIDIA DGX machines.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
gpu_id: integer index of a GPU
|
265 |
+
nproc_per_node: number of processes per node
|
266 |
+
mode: affinity mode
|
267 |
+
balanced: assign an equal number of physical cores to each process,
|
268 |
+
affects only 'socket_unique_interleaved' and
|
269 |
+
'socket_unique_continuous' affinity modes
|
270 |
+
|
271 |
+
Returns a set of logical CPU cores on which the process is eligible to run.
|
272 |
+
|
273 |
+
Example:
|
274 |
+
|
275 |
+
import argparse
|
276 |
+
import os
|
277 |
+
|
278 |
+
import gpu_affinity
|
279 |
+
import torch
|
280 |
+
|
281 |
+
|
282 |
+
def main():
|
283 |
+
parser = argparse.ArgumentParser()
|
284 |
+
parser.add_argument(
|
285 |
+
'--local_rank',
|
286 |
+
type=int,
|
287 |
+
default=os.getenv('LOCAL_RANK', 0),
|
288 |
+
)
|
289 |
+
args = parser.parse_args()
|
290 |
+
|
291 |
+
nproc_per_node = torch.cuda.device_count()
|
292 |
+
|
293 |
+
affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
|
294 |
+
print(f'{args.local_rank}: core affinity: {affinity}')
|
295 |
+
|
296 |
+
|
297 |
+
if __name__ == "__main__":
|
298 |
+
main()
|
299 |
+
|
300 |
+
Launch the example with:
|
301 |
+
python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py
|
302 |
+
|
303 |
+
|
304 |
+
WARNING: On DGX A100 only a half of CPU cores have direct access to GPUs.
|
305 |
+
This function restricts execution only to the CPU cores directly connected
|
306 |
+
to GPUs, so on DGX A100 it will limit the code to half of CPU cores and half
|
307 |
+
of CPU memory bandwidth (which may be fine for many DL models).
|
308 |
+
"""
|
309 |
+
pynvml.nvmlInit()
|
310 |
+
|
311 |
+
if mode == "socket":
|
312 |
+
set_socket_affinity(gpu_id)
|
313 |
+
elif mode == "single":
|
314 |
+
set_single_affinity(gpu_id)
|
315 |
+
elif mode == "single_unique":
|
316 |
+
set_single_unique_affinity(gpu_id, nproc_per_node)
|
317 |
+
elif mode == "socket_unique_interleaved":
|
318 |
+
set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
|
319 |
+
elif mode == "socket_unique_continuous":
|
320 |
+
set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
|
321 |
+
else:
|
322 |
+
raise RuntimeError("Unknown affinity mode")
|
323 |
+
|
324 |
+
affinity = os.sched_getaffinity(0)
|
325 |
+
return affinity
|
env/SE3Transformer/build/lib/se3_transformer/runtime/inference.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from typing import List
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
from torch.nn.parallel import DistributedDataParallel
|
29 |
+
from torch.utils.data import DataLoader
|
30 |
+
from tqdm import tqdm
|
31 |
+
|
32 |
+
from se3_transformer.runtime import gpu_affinity
|
33 |
+
from se3_transformer.runtime.arguments import PARSER
|
34 |
+
from se3_transformer.runtime.callbacks import BaseCallback
|
35 |
+
from se3_transformer.runtime.loggers import DLLogger
|
36 |
+
from se3_transformer.runtime.utils import to_cuda, get_local_rank
|
37 |
+
|
38 |
+
|
39 |
+
@torch.inference_mode()
|
40 |
+
def evaluate(model: nn.Module,
|
41 |
+
dataloader: DataLoader,
|
42 |
+
callbacks: List[BaseCallback],
|
43 |
+
args):
|
44 |
+
model.eval()
|
45 |
+
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc=f'Evaluation',
|
46 |
+
leave=False, disable=(args.silent or get_local_rank() != 0)):
|
47 |
+
*input, target = to_cuda(batch)
|
48 |
+
|
49 |
+
for callback in callbacks:
|
50 |
+
callback.on_batch_start()
|
51 |
+
|
52 |
+
with torch.cuda.amp.autocast(enabled=args.amp):
|
53 |
+
pred = model(*input)
|
54 |
+
|
55 |
+
for callback in callbacks:
|
56 |
+
callback.on_validation_step(input, target, pred)
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
|
61 |
+
from se3_transformer.runtime.utils import init_distributed, seed_everything
|
62 |
+
from se3_transformer.model import SE3TransformerPooled, Fiber
|
63 |
+
from se3_transformer.data_loading import QM9DataModule
|
64 |
+
import torch.distributed as dist
|
65 |
+
import logging
|
66 |
+
import sys
|
67 |
+
|
68 |
+
is_distributed = init_distributed()
|
69 |
+
local_rank = get_local_rank()
|
70 |
+
args = PARSER.parse_args()
|
71 |
+
|
72 |
+
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
|
73 |
+
|
74 |
+
logging.info('====== SE(3)-Transformer ======')
|
75 |
+
logging.info('| Inference on the test set |')
|
76 |
+
logging.info('===============================')
|
77 |
+
|
78 |
+
if not args.benchmark and args.load_ckpt_path is None:
|
79 |
+
logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
|
80 |
+
sys.exit(1)
|
81 |
+
|
82 |
+
if args.benchmark:
|
83 |
+
logging.info('Running benchmark mode with one warmup pass')
|
84 |
+
|
85 |
+
if args.seed is not None:
|
86 |
+
seed_everything(args.seed)
|
87 |
+
|
88 |
+
major_cc, minor_cc = torch.cuda.get_device_capability()
|
89 |
+
|
90 |
+
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
|
91 |
+
datamodule = QM9DataModule(**vars(args))
|
92 |
+
model = SE3TransformerPooled(
|
93 |
+
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
94 |
+
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
|
95 |
+
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
|
96 |
+
output_dim=1,
|
97 |
+
tensor_cores=(args.amp and major_cc >= 7) or major_cc >= 8, # use Tensor Cores more effectively
|
98 |
+
**vars(args)
|
99 |
+
)
|
100 |
+
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]
|
101 |
+
|
102 |
+
model.to(device=torch.cuda.current_device())
|
103 |
+
if args.load_ckpt_path is not None:
|
104 |
+
checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
|
105 |
+
model.load_state_dict(checkpoint['state_dict'])
|
106 |
+
|
107 |
+
if is_distributed:
|
108 |
+
nproc_per_node = torch.cuda.device_count()
|
109 |
+
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
|
110 |
+
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
111 |
+
|
112 |
+
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
|
113 |
+
evaluate(model,
|
114 |
+
test_dataloader,
|
115 |
+
callbacks,
|
116 |
+
args)
|
117 |
+
|
118 |
+
for callback in callbacks:
|
119 |
+
callback.on_validation_end()
|
120 |
+
|
121 |
+
if args.benchmark:
|
122 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
123 |
+
callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
|
124 |
+
for _ in range(6):
|
125 |
+
evaluate(model,
|
126 |
+
test_dataloader,
|
127 |
+
callbacks,
|
128 |
+
args)
|
129 |
+
callbacks[0].on_epoch_end()
|
130 |
+
|
131 |
+
callbacks[0].on_fit_end()
|
env/SE3Transformer/build/lib/se3_transformer/runtime/loggers.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import pathlib
|
25 |
+
from abc import ABC, abstractmethod
|
26 |
+
from enum import Enum
|
27 |
+
from typing import Dict, Any, Callable, Optional
|
28 |
+
|
29 |
+
import dllogger
|
30 |
+
import torch.distributed as dist
|
31 |
+
import wandb
|
32 |
+
from dllogger import Verbosity
|
33 |
+
|
34 |
+
from se3_transformer.runtime.utils import rank_zero_only
|
35 |
+
|
36 |
+
|
37 |
+
class Logger(ABC):
|
38 |
+
@rank_zero_only
|
39 |
+
@abstractmethod
|
40 |
+
def log_hyperparams(self, params):
|
41 |
+
pass
|
42 |
+
|
43 |
+
@rank_zero_only
|
44 |
+
@abstractmethod
|
45 |
+
def log_metrics(self, metrics, step=None):
|
46 |
+
pass
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def _sanitize_params(params):
|
50 |
+
def _sanitize(val):
|
51 |
+
if isinstance(val, Callable):
|
52 |
+
try:
|
53 |
+
_val = val()
|
54 |
+
if isinstance(_val, Callable):
|
55 |
+
return val.__name__
|
56 |
+
return _val
|
57 |
+
except Exception:
|
58 |
+
return getattr(val, "__name__", None)
|
59 |
+
elif isinstance(val, pathlib.Path) or isinstance(val, Enum):
|
60 |
+
return str(val)
|
61 |
+
return val
|
62 |
+
|
63 |
+
return {key: _sanitize(val) for key, val in params.items()}
|
64 |
+
|
65 |
+
|
66 |
+
class LoggerCollection(Logger):
|
67 |
+
def __init__(self, loggers):
|
68 |
+
super().__init__()
|
69 |
+
self.loggers = loggers
|
70 |
+
|
71 |
+
def __getitem__(self, index):
|
72 |
+
return [logger for logger in self.loggers][index]
|
73 |
+
|
74 |
+
@rank_zero_only
|
75 |
+
def log_metrics(self, metrics, step=None):
|
76 |
+
for logger in self.loggers:
|
77 |
+
logger.log_metrics(metrics, step)
|
78 |
+
|
79 |
+
@rank_zero_only
|
80 |
+
def log_hyperparams(self, params):
|
81 |
+
for logger in self.loggers:
|
82 |
+
logger.log_hyperparams(params)
|
83 |
+
|
84 |
+
|
85 |
+
class DLLogger(Logger):
|
86 |
+
def __init__(self, save_dir: pathlib.Path, filename: str):
|
87 |
+
super().__init__()
|
88 |
+
if not dist.is_initialized() or dist.get_rank() == 0:
|
89 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
90 |
+
dllogger.init(
|
91 |
+
backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])
|
92 |
+
|
93 |
+
@rank_zero_only
|
94 |
+
def log_hyperparams(self, params):
|
95 |
+
params = self._sanitize_params(params)
|
96 |
+
dllogger.log(step="PARAMETER", data=params)
|
97 |
+
|
98 |
+
@rank_zero_only
|
99 |
+
def log_metrics(self, metrics, step=None):
|
100 |
+
if step is None:
|
101 |
+
step = tuple()
|
102 |
+
|
103 |
+
dllogger.log(step=step, data=metrics)
|
104 |
+
|
105 |
+
|
106 |
+
class WandbLogger(Logger):
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
name: str,
|
110 |
+
save_dir: pathlib.Path,
|
111 |
+
id: Optional[str] = None,
|
112 |
+
project: Optional[str] = None
|
113 |
+
):
|
114 |
+
super().__init__()
|
115 |
+
if not dist.is_initialized() or dist.get_rank() == 0:
|
116 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
117 |
+
self.experiment = wandb.init(name=name,
|
118 |
+
project=project,
|
119 |
+
id=id,
|
120 |
+
dir=str(save_dir),
|
121 |
+
resume='allow',
|
122 |
+
anonymous='must')
|
123 |
+
|
124 |
+
@rank_zero_only
|
125 |
+
def log_hyperparams(self, params: Dict[str, Any]) -> None:
|
126 |
+
params = self._sanitize_params(params)
|
127 |
+
self.experiment.config.update(params, allow_val_change=True)
|
128 |
+
|
129 |
+
@rank_zero_only
|
130 |
+
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
131 |
+
if step is not None:
|
132 |
+
self.experiment.log({**metrics, 'epoch': step})
|
133 |
+
else:
|
134 |
+
self.experiment.log(metrics)
|
env/SE3Transformer/build/lib/se3_transformer/runtime/metrics.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from abc import ABC, abstractmethod
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.distributed as dist
|
28 |
+
from torch import Tensor
|
29 |
+
|
30 |
+
|
31 |
+
class Metric(ABC):
|
32 |
+
""" Metric class with synchronization capabilities similar to TorchMetrics """
|
33 |
+
|
34 |
+
def __init__(self):
|
35 |
+
self.states = {}
|
36 |
+
|
37 |
+
def add_state(self, name: str, default: Tensor):
|
38 |
+
assert name not in self.states
|
39 |
+
self.states[name] = default.clone()
|
40 |
+
setattr(self, name, default)
|
41 |
+
|
42 |
+
def synchronize(self):
|
43 |
+
if dist.is_initialized():
|
44 |
+
for state in self.states:
|
45 |
+
dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)
|
46 |
+
|
47 |
+
def __call__(self, *args, **kwargs):
|
48 |
+
self.update(*args, **kwargs)
|
49 |
+
|
50 |
+
def reset(self):
|
51 |
+
for name, default in self.states.items():
|
52 |
+
setattr(self, name, default.clone())
|
53 |
+
|
54 |
+
def compute(self):
|
55 |
+
self.synchronize()
|
56 |
+
value = self._compute().item()
|
57 |
+
self.reset()
|
58 |
+
return value
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def _compute(self):
|
62 |
+
pass
|
63 |
+
|
64 |
+
@abstractmethod
|
65 |
+
def update(self, preds: Tensor, targets: Tensor):
|
66 |
+
pass
|
67 |
+
|
68 |
+
|
69 |
+
class MeanAbsoluteError(Metric):
|
70 |
+
def __init__(self):
|
71 |
+
super().__init__()
|
72 |
+
self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
|
73 |
+
self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))
|
74 |
+
|
75 |
+
def update(self, preds: Tensor, targets: Tensor):
|
76 |
+
preds = preds.detach()
|
77 |
+
n = preds.shape[0]
|
78 |
+
error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
|
79 |
+
self.total += n
|
80 |
+
self.error += error
|
81 |
+
|
82 |
+
def _compute(self):
|
83 |
+
return self.error / self.total
|