GlandVergil commited on
Commit
6b89792
·
verified ·
1 Parent(s): 0793996

Upload 597 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .github/CODEOWNERS +2 -0
  3. .ipynb_checkpoints/untitled-checkpoint.py +0 -0
  4. .rosetta-ci/.gitignore +3 -0
  5. .rosetta-ci/benchmark.py +410 -0
  6. .rosetta-ci/benchmark.template.ini +40 -0
  7. .rosetta-ci/hpc_drivers/__init__.py +5 -0
  8. .rosetta-ci/hpc_drivers/base.py +210 -0
  9. .rosetta-ci/hpc_drivers/multicore.py +184 -0
  10. .rosetta-ci/hpc_drivers/slurm.py +176 -0
  11. .rosetta-ci/test-sets.yaml +65 -0
  12. .rosetta-ci/tests/__init__.py +765 -0
  13. .rosetta-ci/tests/rfd.py +111 -0
  14. .rosetta-ci/tests/self.md +6 -0
  15. .rosetta-ci/tests/self.py +209 -0
  16. config/inference/base.yaml +136 -0
  17. config/inference/symmetry.yaml +26 -0
  18. docker/Dockerfile +50 -0
  19. env/SE3Transformer/.dockerignore +123 -0
  20. env/SE3Transformer/.gitignore +121 -0
  21. env/SE3Transformer/Dockerfile +58 -0
  22. env/SE3Transformer/LICENSE +7 -0
  23. env/SE3Transformer/NOTICE +7 -0
  24. env/SE3Transformer/README.md +580 -0
  25. env/SE3Transformer/build/lib/se3_transformer/__init__.py +0 -0
  26. env/SE3Transformer/build/lib/se3_transformer/data_loading/__init__.py +1 -0
  27. env/SE3Transformer/build/lib/se3_transformer/data_loading/data_module.py +63 -0
  28. env/SE3Transformer/build/lib/se3_transformer/data_loading/qm9.py +173 -0
  29. env/SE3Transformer/build/lib/se3_transformer/model/__init__.py +2 -0
  30. env/SE3Transformer/build/lib/se3_transformer/model/basis.py +178 -0
  31. env/SE3Transformer/build/lib/se3_transformer/model/fiber.py +144 -0
  32. env/SE3Transformer/build/lib/se3_transformer/model/layers/__init__.py +5 -0
  33. env/SE3Transformer/build/lib/se3_transformer/model/layers/attention.py +180 -0
  34. env/SE3Transformer/build/lib/se3_transformer/model/layers/convolution.py +336 -0
  35. env/SE3Transformer/build/lib/se3_transformer/model/layers/linear.py +59 -0
  36. env/SE3Transformer/build/lib/se3_transformer/model/layers/norm.py +83 -0
  37. env/SE3Transformer/build/lib/se3_transformer/model/layers/pooling.py +53 -0
  38. env/SE3Transformer/build/lib/se3_transformer/model/transformer.py +222 -0
  39. env/SE3Transformer/build/lib/se3_transformer/runtime/__init__.py +0 -0
  40. env/SE3Transformer/build/lib/se3_transformer/runtime/arguments.py +70 -0
  41. env/SE3Transformer/build/lib/se3_transformer/runtime/callbacks.py +160 -0
  42. env/SE3Transformer/build/lib/se3_transformer/runtime/gpu_affinity.py +325 -0
  43. env/SE3Transformer/build/lib/se3_transformer/runtime/inference.py +131 -0
  44. env/SE3Transformer/build/lib/se3_transformer/runtime/loggers.py +134 -0
  45. env/SE3Transformer/build/lib/se3_transformer/runtime/metrics.py +83 -0
  46. env/SE3Transformer/build/lib/se3_transformer/runtime/training.py +238 -0
  47. env/SE3Transformer/build/lib/se3_transformer/runtime/utils.py +130 -0
  48. env/SE3Transformer/build/lib/tests/__init__.py +0 -0
  49. env/SE3Transformer/build/lib/tests/test_equivariance.py +102 -0
  50. env/SE3Transformer/build/lib/tests/utils.py +60 -0
.gitattributes CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  RFdiffusion/env/SE3Transformer/images/se3-transformer.png filter=lfs diff=lfs merge=lfs -text
37
  RFdiffusion/img/diffusion_protein_gradient_2.jpg filter=lfs diff=lfs merge=lfs -text
38
  RFdiffusion/pyrosetta-2023.14+release.7132bdc754a-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
 
 
 
36
  RFdiffusion/env/SE3Transformer/images/se3-transformer.png filter=lfs diff=lfs merge=lfs -text
37
  RFdiffusion/img/diffusion_protein_gradient_2.jpg filter=lfs diff=lfs merge=lfs -text
38
  RFdiffusion/pyrosetta-2023.14+release.7132bdc754a-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
39
+ env/SE3Transformer/images/se3-transformer.png filter=lfs diff=lfs merge=lfs -text
40
+ img/diffusion_protein_gradient_2.jpg filter=lfs diff=lfs merge=lfs -text
.github/CODEOWNERS ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Benchmark scripts
2
+ /.rosetta-ci @lyskov
.ipynb_checkpoints/untitled-checkpoint.py ADDED
File without changes
.rosetta-ci/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.pyc
2
+ results/
3
+ benchmark.ubuntu.ini
.rosetta-ci/benchmark.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # :noTabs=true:
4
+
5
+ # (c) Copyright Rosetta Commons Member Institutions.
6
+ # (c) This file is part of the Rosetta software suite and is made available under license.
7
+ # (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
8
+ # (c) For more information, see http://www.rosettacommons.org. Questions about this can be
9
+ # (c) addressed to University of Washington CoMotion, email: license@uw.edu.
10
+
11
+ ## @file benchmark.py
12
+ ## @brief Run arbitrary Rosetta testing script
13
+ ## @author Sergey Lyskov
14
+
15
+ from __future__ import print_function
16
+
17
+ import os, os.path, sys, shutil, json, platform, re
18
+ import codecs
19
+
20
+ from importlib.machinery import SourceFileLoader
21
+
22
+ from configparser import ConfigParser, ExtendedInterpolation
23
+ import argparse
24
+
25
+ from tests import * # execute, Tests states and key names
26
+ from hpc_drivers import *
27
+
28
+
29
+ # Calculating value of Platform dict
30
+ Platform = {}
31
+ if sys.platform.startswith("linux"):
32
+ Platform['os'] = 'ubuntu' if os.path.isfile('/etc/lsb-release') and 'Ubuntu' in open('/etc/lsb-release').read() else 'linux' # can be linux1, linux2, etc
33
+ elif sys.platform == "darwin" : Platform['os'] = 'mac'
34
+ elif sys.platform == "cygwin" : Platform['os'] = 'cygwin'
35
+ elif sys.platform == "win32" : Platform['os'] = 'windows'
36
+ else: Platform['os'] = 'unknown'
37
+
38
+ #Platform['arch'] = platform.architecture()[0][:2] # PlatformBits
39
+ Platform['compiler'] = 'gcc' if Platform['os'] == 'linux' else 'clang'
40
+
41
+ Platform['python'] = sys.executable
42
+
43
+
44
+ def load_python_source_from_file(module_name, module_path):
45
+ ''' replacment for deprecated imp.load_source
46
+ '''
47
+ return SourceFileLoader(module_name, module_path).load_module()
48
+
49
+
50
+ class Setup(object):
51
+ __slots__ = 'test working_dir platform config compare debug'.split() # version daemon path_to_previous_test
52
+ def __init__(self, **attrs):
53
+ #self.daemon = True
54
+ for k, v in attrs.items():
55
+ if k in self.__slots__: setattr(self, k, v)
56
+
57
+
58
+ def setup_from_options(options):
59
+ ''' Create Setup object based on user supplied options, config files and auto-detection
60
+ '''
61
+ platform = dict(Platform)
62
+
63
+ if options.suffix: options.suffix = '.' + options.suffix
64
+
65
+ platform['extras'] = options.extras.split(',') if options.extras else []
66
+ platform['python'] = options.python
67
+ #platform['options'] = json.loads( options.options ) if options.options else {}
68
+
69
+ if options.memory: memory = options.memory
70
+ elif platform['os'] in ['linux', 'ubuntu']: memory = int( execute('Getting memory info...', 'free -m', terminate_on_failure=False, silent=True, silence_output_on_errors=True, return_='output').split('\n')[1].split()[1]) // 1024
71
+ elif platform['os'] == 'mac': memory = int( execute('Getting memory info...', 'sysctl -a | grep hw.memsize', terminate_on_failure=False, silent=True, silence_output_on_errors=True, return_='output').split()[1]) // 1024 // 1024 // 1024
72
+
73
+ platform['compiler'] = options.compiler
74
+
75
+ if os.path.isfile(options.config):
76
+ with open(options.config) as f:
77
+ if '%(here)s' in f.read():
78
+ print(f"\n\n>>> ERROR file `{options.config}` seems to be in outdated format! Please use benchmark.template.ini to update it.")
79
+ sys.exit(1)
80
+
81
+ user_config = ConfigParser(
82
+ dict(
83
+ _here_ = os.path.abspath('./'),
84
+ _user_home_ = os.environ['HOME']
85
+ ),
86
+ interpolation = ExtendedInterpolation()
87
+ )
88
+
89
+ with open(options.config) as f: user_config.readfp(f)
90
+
91
+ else:
92
+ print(f"\n\n>>> Config file `{options.config}` not found. You may want to manually copy `benchmark.ini.template` to `{options.config}` and edit the settings\n\n")
93
+ user_config = ConfigParser()
94
+ user_config.set('main', 'cpu_count', '1')
95
+ user_config.set('main', 'hpc_driver', 'MultiCore')
96
+ user_config.set('main', 'branch', 'unknown')
97
+ user_config.set('main', 'revision', '42')
98
+ user_config.set('main', 'user_name', 'Jane Roe')
99
+ user_config.set('main', 'user_email', 'jane.roe@university.edu')
100
+ user_config.add_section('main')
101
+
102
+ if options.jobs: user_config.set('main', 'cpu_count', str(options.jobs) )
103
+ user_config.set('main', 'memory', str(memory) )
104
+
105
+ if options.mount:
106
+ for m in options.mount:
107
+ key, _, path = m.partition(':')
108
+ user_config.set('mount', key, path)
109
+
110
+ #config = Config.items('config')
111
+ #for section in config.sections(): print('Config section: ', section, dict(config.items(section)))
112
+ #config = { section: dict(Config.items(section)) for section in Config.sections() }
113
+
114
+ config = { k : d for k, d in user_config['main'].items() if k not in user_config[user_config.default_section] }
115
+ config['mounts'] = { k : d for k, d in user_config['mount'].items() if k not in user_config[user_config.default_section] }
116
+
117
+ #print(json.dumps(config, sort_keys=True, indent=2)); sys.exit(1)
118
+
119
+ #config.update( config.pop('config').items() )
120
+
121
+ config = dict(config,
122
+ cpu_count = user_config.getint('main', 'cpu_count'),
123
+ memory = memory,
124
+ revision = user_config.getint('main', 'revision'),
125
+ emulation=True,
126
+ ) # debug=options.debug,
127
+
128
+ if 'results_root' not in config: config['results_root'] = os.path.abspath('./results/')
129
+
130
+ if 'prefix' in config:
131
+ assert os.path.isabs( config['prefix'] ), f'ERROR: `prefix` path must be absolute! Got: {config["prefix"]}'
132
+
133
+ else: config['prefix'] = os.path.abspath( config['results_root'] + '/prefix')
134
+
135
+ config['merge_head'] = options.merge_head
136
+ config['merge_base'] = options.merge_base
137
+
138
+ if options.skip_compile is not None: config['skip_compile'] = options.skip_compile
139
+
140
+ #print(f'Results path: {config["results_root"]}')
141
+ #print('Config:{}, Platform:{}'.format(json.dumps(config, sort_keys=True, indent=2), Platform))
142
+
143
+ if options.compare: print('Comparing tests {} with suffixes: {}'.format(options.args, options.compare) )
144
+ else: print('Running tests: {}'.format(options.args) )
145
+
146
+ if len(options.args) != 1: print('Error: Single test-name-to-run should be supplied!'); sys.exit(1)
147
+ else:
148
+ test = options.args[0]
149
+ if test.startswith('tests/'): test = test.partition('tests/')[2][:-3] # removing dir prefix and .py suffix
150
+
151
+ if options.compare:
152
+ compare = options.compare[0], options.compare[1] # (this test suffix, previous test suffix)
153
+ working_dir = os.path.abspath( config['results_root'] + f'/{platform["os"]}.{test}' ) # will be a root dir with sub-dirs (options.compare[0], options.compare[1])
154
+ else:
155
+ compare = None
156
+ working_dir = os.path.abspath( config['results_root'] + f'/{platform["os"]}.{test}{options.suffix}' )
157
+
158
+
159
+ if os.path.isdir(working_dir): shutil.rmtree(working_dir); #print('Removing old job dir %s...' % working_dir) # remove old dir if any
160
+ os.makedirs(working_dir)
161
+
162
+ setup = Setup(
163
+ test = test,
164
+ working_dir = working_dir,
165
+ platform = platform,
166
+ config = config,
167
+ compare = compare,
168
+ debug = options.debug,
169
+ #daemon = False,
170
+ )
171
+
172
+ setup_as_json = json.dumps( { k : getattr(setup, k) for k in setup.__slots__}, sort_keys=True, indent=2)
173
+ with open(working_dir + '/.setup.json', 'w') as f: f.write(setup_as_json)
174
+
175
+ #print(f'Detected hardware platform: {Platform}')
176
+ print(f'Setup: {setup_as_json}')
177
+ return setup
178
+
179
+
180
+ def truncate_log(log):
181
+ _max_log_size_ = 1024*1024*1
182
+ _max_line_size_ = _max_log_size_ // 2
183
+
184
+ if len(log) > _max_log_size_:
185
+ new = log
186
+ lines = log.split('\n')
187
+
188
+ if len(lines) > 256:
189
+ new_lines = lines[:32] + ['...truncated...'] + lines[-128:]
190
+ new = '\n'.join(new_lines)
191
+
192
+ if len(new) > _max_log_size_: # special case for Ninja logs that does not use \n
193
+ lines = re.split(r'[\r\n]*', log) #t.log.split('\r')
194
+ if len(lines) > 256: new = '\n'.join( lines[:32] + ['...truncated...'] + lines[-128:] )
195
+
196
+ if len(new) > _max_log_size_: # going to try to truncate each individual line...
197
+ print(f'Trying to truncate log line-by-line...')
198
+ new = '\n'.join( (
199
+ ( line[:_max_line_size_//3] + '...truncated...' + line[-_max_line_size_//3:] ) if line > _max_line_size_ else line
200
+ for line in new_lines ) )
201
+
202
+ if len(new) > _max_log_size_: # fall-back strategy in case all of the above failed...
203
+ print(f'WARNING: could not truncate log line-by-line, falling back to raw truncate...')
204
+ new = 'WARNING: could not truncate test log line-by-line, falling back to raw truncate!\n...truncated...\n' + ( '\n'.join(lines) )[-_max_log_size_+256:]
205
+
206
+ print( 'Trunacting test output log: {0}MiB --> {1}MiB'.format(len(log)/1024/1024, len(new)/1024/1024) )
207
+
208
+ log = new
209
+
210
+ return log
211
+
212
+ def truncate_results_logs(results):
213
+ results[_LogKey_] = truncate_log( results[_LogKey_] )
214
+ if _ResultsKey_ in results and _TestsKey_ in results[_ResultsKey_]:
215
+ tests = results[_ResultsKey_][_TestsKey_]
216
+ for test in tests:
217
+ tests[test][_LogKey_] = truncate_log( tests[test][_LogKey_] )
218
+
219
+
220
+ def find_test_description(test_name, test_script_file_name):
221
+ ''' return content of test-description file if any or None if no description was found
222
+ '''
223
+
224
+ def find_description_file(prefix, test_name):
225
+ fname = prefix + test_name + '.md'
226
+ if os.path.isfile(fname): return fname
227
+ return prefix + 'md'
228
+
229
+ description_file_name = find_description_file( test_script_file_name[:-len('command.py')] + 'description.', test_name) if test_script_file_name.endswith('/command.py') else find_description_file(test_script_file_name[:-len('py')], test_name)
230
+
231
+ if description_file_name and os.path.isfile(description_file_name):
232
+ print(f'Found test suite description in file: {description_file_name!r}')
233
+ with open(description_file_name, encoding='utf-8', errors='backslashreplace') as f: description = f.read()
234
+ return description
235
+
236
+ else: return None
237
+
238
+
239
+
240
+ def run_test(setup):
241
+ #print(f'{setup!r}')
242
+ suite, rest = setup.test.split('.'), []
243
+ while suite:
244
+ #print( f'suite: {suite}, test: {rest}' )
245
+
246
+ file_name = '/'.join( ['tests'] + suite ) + '.py'
247
+ if os.path.isfile(file_name): break
248
+
249
+ file_name = '/'.join( ['tests'] + suite ) + '/command.py'
250
+ if os.path.isfile(file_name): break
251
+
252
+ rest.insert(0, suite.pop())
253
+
254
+
255
+ test = '.'.join( suite + rest )
256
+ test_name = '.'.join(rest)
257
+
258
+ print( f'Loading test from: {file_name}, suite+test: {test!r}, test: {test_name!r}' )
259
+ #test_suite = imp.load_source('test_suite', file_name)
260
+ test_suite = load_python_source_from_file('test_suite', file_name)
261
+
262
+ test_description = find_test_description(test_name, file_name)
263
+
264
+ if setup.compare:
265
+ #working_dir_1 = os.path.abspath( config['results_root'] + f'/{Platform["os"]}.{test}.{Options.compare[0]}' )
266
+ working_dir_1 = setup.working_dir + f'/{setup.compare[0]}'
267
+
268
+ working_dir_2 = setup.compare[1] and ( setup.working_dir + f'/{setup.compare[1]}' )
269
+ res_2_json_file_path = setup.compare[1] and f'{working_dir_2}/.execution.results.json'
270
+
271
+ with open(working_dir_1 + '/.execution.results.json') as f: res_1 = json.load(f).get(_ResultsKey_)
272
+
273
+ if setup.compare[1] and ( not os.path.isfile(res_2_json_file_path) ):
274
+ setup.compare[1] = None
275
+ state_override = _S_failed_
276
+ else:
277
+ state_override = None
278
+
279
+ if setup.compare[1] == None: res_2, working_dir_2 = None, None
280
+ else:
281
+ with open(res_2_json_file_path) as f: res_2 = json.load(f).get(_ResultsKey_)
282
+
283
+ res = test_suite.compare(test, res_1, working_dir_1, res_2, working_dir_2)
284
+
285
+ if state_override:
286
+ log_prefix = \
287
+ f'WARNING: Previous test results does not have `.execution.results.json` file, so comparision with None was performed instead!\n' \
288
+ f'WARNING: Overriding calcualted test state `{res[_StateKey_]}` → `{_S_failed_}`...\n\n'
289
+
290
+ res[_LogKey_] = log_prefix + res[_LogKey_]
291
+ res[_StateKey_] = _S_failed_
292
+
293
+
294
+ # # Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages.
295
+ # with codecs.open(setup.working_dir+'/.comparison.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write( truncate_log( res[_LogKey_] ) )
296
+ # res[_LogKey_] = truncate_log( res[_LogKey_] )
297
+
298
+ # # Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages.
299
+ with codecs.open(setup.working_dir+'/.comparison.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write(res[_LogKey_])
300
+ truncate_results_logs(res)
301
+
302
+ print( 'Comparison finished with output:\n{}'.format( res[_LogKey_] ) )
303
+
304
+ with open(setup.working_dir+'/.comparison.results.json', 'w') as f: json.dump(res, f, sort_keys=True, indent=2)
305
+
306
+ #print( 'Comparison finished with results:\n{}'.format( json.dumps(res, sort_keys=True, indent=2) ) )
307
+ if 'summary' in res: print('Summary section:\n{}'.format( json.dumps(res['summary'], sort_keys=True, indent=2) ) )
308
+
309
+ print( f'Output results of this comparison saved to {working_dir_1}/.comparison.results.json\nComparison log saved into {working_dir_1}/.comparison.log.txt' )
310
+
311
+
312
+ else:
313
+ working_dir = setup.working_dir #os.path.abspath( setup.config['results_root'] + f'/{platform["os"]}.{test}{options.suffix}' )
314
+
315
+ hpc_driver_name = setup.config['hpc_driver']
316
+ hpc_driver = None if hpc_driver_name in ['', 'none'] else eval(hpc_driver_name + '_HPC_Driver')(working_dir, setup.config, tracer=print, set_daemon_message=lambda x:None)
317
+
318
+ api_version = test_suite._api_version_ if hasattr(test_suite, '_api_version_') else ''
319
+
320
+ # if api_version < '1.0':
321
+ # res = test_suite.run(test=test_name, rosetta_dir=os.path.abspath('../..'), working_dir=working_dir, platform=dict(Platform), jobs=Config.cpu_count, verbose=True, debug=Options.debug)
322
+ # else:
323
+
324
+ if api_version == '1.0': res = test_suite.run(test=test_name, repository_root=os.path.abspath('./..'), working_dir=working_dir, platform=dict(setup.platform), config=setup.config, hpc_driver=hpc_driver, verbose=True, debug=setup.debug)
325
+ else:
326
+ print(f'Test benchmark api_version={api_version} is not supported!'); sys.exit(1)
327
+
328
+ if not isinstance(res, dict): print(f'Test returned result of type {type(res)} while dict-like object was expected, please check that test-script have correct `return` statment! Terminating...'); sys.exit(1)
329
+
330
+ # Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages
331
+ with codecs.open(working_dir+'/.execution.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write( res[_LogKey_] )
332
+
333
+ # res[_LogKey_] = truncate_log( res[_LogKey_] )
334
+ truncate_results_logs(res)
335
+
336
+ if _DescriptionKey_ not in res: res[_DescriptionKey_] = test_description
337
+
338
+ if res[_StateKey_] not in _S_Values_: print( 'Warning!!! Test {} failed with unknow result code: {}'.format(test_name, res[_StateKey_]) )
339
+ else: print( f'Test {test} finished with output:\n{res[_LogKey_]}\n----------------------------------------------------------------\nState: {res[_StateKey_]!r} | ', end='')
340
+
341
+ # JSON by default serializes to an ascii-encoded format
342
+ with open(working_dir+'/.execution.results.json', 'w') as f: json.dump(res, f, sort_keys=True, indent=2)
343
+
344
+ print( f'Output and full log of this test saved to:\n{working_dir}/.execution.results.json\n{working_dir}/.execution.log.txt' )
345
+
346
+
347
+
348
+
349
+
350
+
351
+ def main(args):
352
+ ''' Script to Run arbitrary Rosetta test
353
+ '''
354
+ parser = argparse.ArgumentParser(usage="Main testing script to run tests in the tests directory. "
355
+ "Use the --skip-compile to skip the build phase when testing locally. "
356
+ "Example Command: /benchmark.py -j2 integration.valgrind")
357
+
358
+ parser.add_argument('-j', '--jobs', default=0, type=int, help="Number of processors to use on when building. (default: use value from config file or 1)")
359
+
360
+ parser.add_argument('-m', '--memory', default=0, type=int, help="Amount of memory to use (default: use 2Gb per job")
361
+
362
+ parser.add_argument('--compiler', default=Platform['compiler'], help="Compiler to use")
363
+
364
+ #parser.add_argument('--python', default=('3.9' if Platform['os'] == 'mac' else '3.6'), help="Python interpreter to use")
365
+ parser.add_argument('--python', default=f'{sys.version_info.major}.{sys.version_info.minor}.s', help="Specify version of Python interpreter to use, for example '3.9'. If '.s' added to end of version string then use the same interpreter that was used to start this script. Default: '?.?.s'")
366
+
367
+ parser.add_argument("--extras", default='', help="Specify scons extras separated by ',': like --extras=mpi,static" )
368
+
369
+ parser.add_argument("--debug", action="store_true", dest="debug", default=False, help="Run specified test in debug mode (not with debug build!) this mean different things and depend on the test. Could be: skip the build phase, skip some of the test phases and so on. [off by default]" )
370
+
371
+ parser.add_argument("--suffix", default='', help="Specify ending suffix for test output dir. This is useful when you want to save test results in different dir for later comparison." )
372
+
373
+ parser.add_argument("--compare", nargs=2, help="Do not run the tests but instead compare previous results. Use --compare suffix1 suffix2" )
374
+
375
+ parser.add_argument("--config", default='benchmark.{os}.ini'.format(os=Platform['os']), action="store", help="Location of .ini file with additional options configuration. Optional.")
376
+
377
+ parser.add_argument("--skip-compile", dest='skip_compile', default=None, action="store_true", help="Skip the compilation phase. Assumes the binaries are already compiled locally.")
378
+
379
+ #parser.add_argument("--results-root", default=None, action="store", help="Location of `results` dir default is to use `./results`")
380
+
381
+ parser.add_argument("--setup", default=None, help="Specify JSON file with setup information. When this option supplied all other config and commandline options is ignored and auto-detection disable. Test, platform info will be gathered from provided JSON file. This option is designed to be used in daemon mode." )
382
+
383
+ parser.add_argument("--merge-head", default='HEAD', help="Specify SHA1/branch-name that will be used for `merge-head` value when simulating PR testing" )
384
+
385
+ parser.add_argument("--merge-base", default='origin/master', help="Specify SHA1/branch-name that will be used for `merge-base` value when simulating PR testing" )
386
+
387
+ parser.add_argument("--mount", action="append", help="Specify one of the mount points, like: --mount release_root:/some/path. This option could be used multiple times if needed" )
388
+
389
+
390
+ parser.add_argument('args', nargs=argparse.REMAINDER)
391
+
392
+ options = parser.parse_args(args=args[1:])
393
+
394
+ if any( [a.startswith('-') for a in options.args] ) :
395
+ print( '\nWARNING WARNING WARNING WARNING\n' )
396
+ print( '\tInterpreting', ' '.join(["'"+a+"'" for a in options.args if a.startswith('-')]), 'as test name(s), rather than as option(s).' )
397
+ print( "\tTry moving it before any test name, if that's not what you want." )
398
+ print( '\nWARNING WARNING WARNING WARNING\n' )
399
+
400
+
401
+ if options.setup:
402
+ with open(options.setup) as f: setup = Setup( **json.load(f) )
403
+
404
+ else:
405
+ setup = setup_from_options(options)
406
+
407
+ run_test(setup)
408
+
409
+
410
+ if __name__ == "__main__": main(sys.argv)
.rosetta-ci/benchmark.template.ini ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Benchmark script configuration file. Some of the tests require some system specific options to run. Please see benchmark.ini.template for list of available options.
3
+ #
4
+
5
+ [DEFAULT]
6
+
7
+ [main] # additional config-options for various tests. All this fields will be pass as keys in 'config' function argument
8
+
9
+ # how many jobs daemon can run on host machine (this is not related to HPC jobs)
10
+ cpu_count = 24
11
+
12
+ # how many memory in GB daemon can use on host machine (approximation, float)
13
+ memory = 64
14
+
15
+ # user name and email for user who submitted this test
16
+ user_name = Jane Roe
17
+ user_email = jane.roe@university.edu
18
+
19
+ # HPC Driver, might have one of the following values: MultiCore, Condor, Slurm or none if no HPC Driver should be configured
20
+ hpc_driver = MultiCore
21
+
22
+ # when running by daemons branch:revision will be set to appropriate values to represent currently checked version of main repository
23
+ branch = unknown
24
+ revision = 42
25
+
26
+ # path to directory where test results will be stored
27
+ results_root = ${_here_}/results
28
+
29
+ release_root = ./results/_release_
30
+
31
+ [slurm]
32
+ # head-node host name, if specified will be used to submit jobs
33
+ head_node =
34
+
35
+
36
+ [mount]
37
+ # list of key:path pairs that will be avalible as config.mounts during test run
38
+
39
+ # path to releases, leave empty if release production should not be supported by this daemon
40
+ release_root = ${_here_}/release
.rosetta-ci/hpc_drivers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # :noTabs=true:
3
+
4
+ from .multicore import MultiCore_HPC_Driver
5
+ from .slurm import Slurm_HPC_Driver
.rosetta-ci/hpc_drivers/base.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # :noTabs=true:
3
+
4
+ import os, sys, subprocess, stat
5
+ import time as time_module
6
+ import signal as signal_module
7
+
8
+ class NT: # named tuple
9
+ def __init__(self, **entries): self.__dict__.update(entries)
10
+ def __repr__(self):
11
+ r = 'NT: |'
12
+ for i in dir(self):
13
+ if not i.startswith('__') and not isinstance(getattr(self, i), types.MethodType): r += '{} --> {}, '.format(i, getattr(self, i))
14
+ return r[:-2]+'|'
15
+
16
+
17
+
18
+ class HPC_Exception(Exception):
19
+ def __init__(self, value): self.value = value
20
+ def __str__(self): return self.value
21
+
22
+
23
+
24
+ def execute(message, command_line, return_='status', until_successes=False, terminate_on_failure=True, silent=False, silence_output=False, tracer=print):
25
+ if not silent: tracer(message); tracer(command_line); sys.stdout.flush();
26
+ while True:
27
+
28
+ p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
29
+ output, errors = p.communicate()
30
+
31
+ output = output + errors
32
+
33
+ output = output.decode(encoding="utf-8", errors="replace")
34
+
35
+ exit_code = p.returncode
36
+
37
+ if exit_code and not (silent or silence_output): tracer(output); sys.stdout.flush();
38
+
39
+ if exit_code and until_successes: pass # Thats right - redability COUNT!
40
+ else: break
41
+
42
+ tracer( "Error while executing {}: {}\n".format(message, output) )
43
+ tracer("Sleeping 60s... then I will retry...")
44
+ sys.stdout.flush();
45
+ time.sleep(60)
46
+
47
+ if return_ == 'tuple': return(exit_code, output)
48
+
49
+ if exit_code and terminate_on_failure:
50
+ tracer("\nEncounter error while executing: " + command_line)
51
+ if return_==True: return True
52
+ else: print("\nEncounter error while executing: " + command_line + '\n' + output); sys.exit(1)
53
+
54
+ if return_ == 'output': return output
55
+ else: return False
56
+
57
+
58
+ def Sleep(time_, message, dict_={}):
59
+ ''' Fancy sleep function '''
60
+ len_ = 0
61
+ for i in range(time_, 0, -1):
62
+ #print "Waiting for a new revision:%s... Sleeping...%d \r" % (sc.revision, i),
63
+ msg = message.format( **dict(dict_, time_left=i) )
64
+ print( msg, end='' )
65
+ len_ = max(len_, len(msg))
66
+ sys.stdout.flush()
67
+ time_module.sleep(1)
68
+
69
+ print( ' '*len_ + '\r', end='' ) # erazing sleep message
70
+
71
+
72
+ # Abstract class for HPC job submission
73
+ class HPC_Driver:
74
+ def __init__(self, working_dir, config, tracer=lambda x:None, set_daemon_message=lambda x:None):
75
+ self.working_dir = working_dir
76
+ self.config = config
77
+ self.cpu_usage = 0.0 # cummulative cpu usage in hours
78
+ self.tracer = tracer
79
+ self.set_daemon_message = set_daemon_message
80
+
81
+ self.cpu_count = self.config['cpu_count'] if type(config) == dict else self.config.getint('DEFAULT', 'cpu_count')
82
+
83
+ self.jobs = [] # list of all jobs currently running by this driver, Job class is driver depended, could be just int or something more complex
84
+
85
+ self.install_signal_handler()
86
+
87
+
88
+ def __del__(self):
89
+ self.remove_signal_handler()
90
+
91
+
92
+ def execute(self, executable, arguments, working_dir, log_dir=None, name='_no_name_', memory=256, time=24, shell_wrapper=False, block=True):
93
+ ''' Execute given command line on HPC cluster, must accumulate cpu hours in self.cpu_usage '''
94
+ if log_dir==None: log_dir=self.working_dir
95
+
96
+ if shell_wrapper:
97
+ shell_wrapper_sh = os.path.abspath(self.working_dir + '/hpc.{}.shell_wrapper.sh'.format(name))
98
+ with file(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
99
+ executable, arguments = shell_wrapper_sh, ''
100
+
101
+ return self.submit_serial_hpc_job(name=name, executable=executable, arguments=arguments, working_dir=working_dir, log_dir=log_dir, jobs_to_queue=1, memory=memory, time=time, block=block, shell_wrapper=shell_wrapper)
102
+
103
+
104
+
105
+ @property
106
+ def number_of_cpu_per_node(self):
107
+ must_be_implemented_in_inherited_classes
108
+
109
+ @property
110
+ def maximum_number_of_mpi_cpu(self):
111
+ must_be_implemented_in_inherited_classes
112
+
113
+
114
+ def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
115
+ print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
116
+ must_be_implemented_in_inherited_classes
117
+
118
+
119
+ def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
120
+ must_be_implemented_in_inherited_classes
121
+
122
+
123
+ def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, memory=512, time=12, block=True, process_coefficient="1", requested_nodes=1, requested_processes_per_node=1):
124
+ ''' submit jobs as MPI job
125
+ process_coefficient should be string representing fraction of process to launch on each node, for example '3 / 4' will start only 75% of MPI process's on each node
126
+ '''
127
+ must_be_implemented_in_inherited_classes
128
+
129
+
130
+ def cancel_all_jobs(self):
131
+ ''' Cancel all HPC jobs known to this driver, use this as signal handler for script termination '''
132
+ for j in self.jobs: self.cancel_job(j)
133
+
134
+ def block_until(self, silent, fn, *args, **kwargs):
135
+ '''
136
+ **fn must have the driver as the first argument**
137
+ example:
138
+ def fn(driver):
139
+ jobs = list(driver.jobs)
140
+ jobs = [job for job in jobs if not driver.complete(job)]
141
+ if len(jobs) <= 8:
142
+ return False # stops sleeping
143
+ return True # continues sleeping
144
+
145
+ for x in range(100):
146
+ hpc_driver.submit_hpc_job(...)
147
+ hpc_driver.block_until(False, fn)
148
+ '''
149
+ while fn(self, *args, **kwargs):
150
+ sys.stdout.flush()
151
+ time_module.sleep(60)
152
+ if not silent:
153
+ Sleep(1, '"Waiting for HPC job(s) to finish, sleeping {time_left}s\r')
154
+
155
+ def wait_until_complete(self, jobs=None, callback=None, silent=False):
156
+ ''' Helper function, wait until given jobs list is finished, if no argument is given waits until all jobs known by driver is finished '''
157
+ jobs = jobs if jobs else self.jobs
158
+
159
+ while jobs:
160
+ for j in jobs[:]:
161
+ if self.complete(j): jobs.remove(j)
162
+
163
+ if jobs:
164
+ #total_cpu_queued = sum( [j.jobs_queued for j in jobs] )
165
+ #total_cpu_running = sum( [j.cpu_running for j in jobs] )
166
+ #self.set_daemon_message("Waiting for HPC job(s) to finish... [{} process(es) in queue, {} process(es) running]".format(total_cpu_queued, total_cpu_running) )
167
+ #self.tracer("Waiting for HPC job(s) [{} process(es) in queue, {} process(es) running]... \r".format(total_cpu_queued, total_cpu_running), end='')
168
+ #print "Waiting for {} HPC jobs to finish... [{} jobs in queue, {} jobs running]... Sleeping 32s... \r".format(total_cpu_queued, cpu_queued+cpu_running, cpu_running),
169
+
170
+ self.set_daemon_message("Waiting for HPC {} job(s) to finish...".format( len(jobs) ) )
171
+ #self.tracer("Waiting for HPC {} job(s) to finish...".format( len(jobs) ) )
172
+
173
+ sys.stdout.flush()
174
+
175
+ if callback: callback()
176
+
177
+ if silent: time_module.sleep(64*1)
178
+ else: Sleep(64, '"Waiting for HPC {n_jobs} job(s) to finish, sleeping {time_left}s \r', dict(n_jobs=len(jobs)))
179
+
180
+
181
+
182
+ _signals_ = [signal_module.SIGINT, signal_module.SIGTERM, signal_module.SIGABRT]
183
+ def install_signal_handler(self):
184
+ def signal_handler(signal_, frame):
185
+ self.tracer('Recieved signal:{}... Canceling HPC jobs...'.format(signal_) )
186
+ self.cancel_all_jobs()
187
+ self.set_daemon_message( 'Remote daemon got terminated with signal:{}'.format(signal_) )
188
+ sys.exit(1)
189
+
190
+ for s in self._signals_: signal_module.signal(s, signal_handler)
191
+
192
+
193
+ def remove_signal_handler(self): # do we really need this???
194
+ try:
195
+ for s in self._signals_: signal_module.signal(s, signal_module.SIG_DFL)
196
+ #print('remove_signal_handler: done!')
197
+
198
+ except TypeError:
199
+ #print('remove_signal_handler: interpreted terminating, skipping remove_signal_handler...')
200
+ pass
201
+
202
+
203
+ def cancel_job(self, job_id):
204
+ must_be_implemented_in_inherited_classes
205
+
206
+
207
+ def complete(self, job_id):
208
+ ''' Return job completion status. Return True if job complered and False otherwise
209
+ '''
210
+ must_be_implemented_in_inherited_classes
.rosetta-ci/hpc_drivers/multicore.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # :noTabs=true:
3
+
4
+ import time as time_module
5
+ import codecs
6
+ import signal
7
+
8
+ import os, sys
9
+
10
+ try:
11
+ from .base import *
12
+
13
+ except ImportError: # workaround for B2 back-end's
14
+ import imp
15
+ imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/base.py') # A bit of Python magic here, what we trying to say is this: from base import *, but path to base is calculated from our source location # from base import HPC_Driver, execute, NT
16
+
17
+
18
+ class MultiCore_HPC_Driver(HPC_Driver):
19
+
20
+ class JobID:
21
+ def __init__(self, pids=None):
22
+ self.pids = pids if pids else []
23
+
24
+
25
+ def __bool__(self): return bool(self.pids)
26
+
27
+
28
+ def __len__(self): return len(self.pids)
29
+
30
+
31
+ def add_pid(self, pid): self.pids.append(pid)
32
+
33
+
34
+ def remove_completed_pids(self):
35
+ for pid in self.pids[:]:
36
+ try:
37
+ r = os.waitpid(pid, os.WNOHANG)
38
+ if r == (pid, 0): self.pids.remove(pid) # process have ended without error
39
+ elif r[0] == pid : # process ended but with error, special case we will have to wait for all process to terminate and call system exit.
40
+ #self.cancel_job()
41
+ #sys.exit(1)
42
+ self.pids.remove(pid)
43
+ print('ERROR: Some of the HPC jobs terminated abnormally! Please see HPC logs for details.')
44
+
45
+ except ChildProcessError: self.pids.remove(pid)
46
+
47
+
48
+ def cancel(self):
49
+ for pid in self.pids:
50
+ try:
51
+ os.killpg(os.getpgid(pid), signal.SIGKILL)
52
+ except ChildProcessError: pass
53
+
54
+ self.pids = []
55
+
56
+
57
+
58
+ def __init__(self, *args, **kwds):
59
+ HPC_Driver.__init__(self, *args, **kwds)
60
+ #print(f'MultiCore_HPC_Driver: cpu_count: {self.cpu_count}')
61
+
62
+
63
+ def remove_completed_jobs(self):
64
+ for job in self.jobs[:]: # Need to make a copy so we don't modify a list we're iterating over
65
+ job.remove_completed_pids()
66
+ if not job: self.jobs.remove(job)
67
+
68
+
69
+ @property
70
+ def process_count(self):
71
+ ''' return number of processes that currently ran by this driver instance
72
+ '''
73
+ return sum( map(len, self.jobs) )
74
+
75
+
76
+ def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
77
+ print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
78
+ return self.submit_serial_hpc_job(name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory, time, block, shell_wrapper)
79
+
80
+
81
+ def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
82
+ cpu_usage = -time_module.time()/60./60.
83
+
84
+ if shell_wrapper:
85
+ shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
86
+ with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
87
+ executable, arguments = shell_wrapper_sh, ''
88
+
89
+ def mfork():
90
+ ''' Check if number of child process is below cpu_count. And if it is - fork the new pocees and return its pid.
91
+ '''
92
+ while self.process_count >= self.cpu_count:
93
+ self.remove_completed_jobs()
94
+ if self.process_count >= self.cpu_count: time_module.sleep(.5)
95
+
96
+ sys.stdout.flush()
97
+ pid = os.fork()
98
+ # appending at caller level insted if pid: self.jobs.append(pid) # We are parent!
99
+ return pid
100
+
101
+ current_job = self.JobID()
102
+ process = 0
103
+ for i in range(jobs_to_queue):
104
+
105
+ pid = mfork()
106
+ if not pid: # we are child process
107
+ command_line = 'cd {} && {} {}'.format(working_dir, executable, arguments.format(process=process) )
108
+ exit_code, log = execute('Running job {}.{}...'.format(name, i), command_line, tracer=self.tracer, return_='tuple')
109
+ with codecs.open(log_dir+'/.hpc.{name}.{i:02d}.log'.format(**vars()), 'w', encoding='utf-8', errors='replace') as f:
110
+ f.write(command_line+'\n'+log)
111
+ if exit_code:
112
+ error_report = f'\n\n{command_line}\nERROR: PROCESS {name}.{i:02d} TERMINATED WITH NON-ZERO-EXIT-CODE {exit_code}!\n'
113
+ f.write(error_report)
114
+ print(log, error_report)
115
+
116
+ sys.exit(0)
117
+
118
+ else: # we are parent!
119
+ current_job.add_pid(pid)
120
+ # Need to potentially re-add to list, as remove_completed_jobs() might trim it.
121
+ if current_job not in self.jobs: self.jobs.append(current_job)
122
+
123
+ process += 1
124
+
125
+ if block:
126
+ #for p in all_queued_jobs: os.waitpid(p, 0) # waiting for all child process to termintate...
127
+
128
+ self.wait_until_complete(current_job)
129
+ self.remove_completed_jobs()
130
+
131
+ cpu_usage += time_module.time()/60./60.
132
+ self.cpu_usage += cpu_usage * jobs_to_queue # approximation...
133
+
134
+ current_job = self.JobID()
135
+
136
+ return current_job
137
+
138
+
139
+ @property
140
+ def number_of_cpu_per_node(self): return self.cpu_count
141
+
142
+
143
+ @property
144
+ def maximum_number_of_mpi_cpu(self): return self.cpu_count
145
+
146
+
147
+ def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, memory=512, time=12, block=True, process_coefficient="1", requested_nodes=1, requested_processes_per_node=1):
148
+
149
+ if requested_nodes > 1:
150
+ print( "WARNING: " + str( requested_nodes ) + " nodes were requested, but we're running locally, so only 1 node will be used." )
151
+
152
+ if requested_processes_per_node > self.cpu_count:
153
+ print( "WARNING: " + str(requested_processes_per_node) + " processes were requested, but I only have " + str(self.cpu_count) + " CPUs. Will launch " + str(self.cpu_count) + " processes." )
154
+ actual_processes = min( requested_processes_per_node, self.cpu_count )
155
+
156
+ cpu_usage = -time_module.time()/60./60.
157
+
158
+ arguments = arguments.format(process=0)
159
+
160
+ command_line = f'cd {working_dir} && mpirun -np {actual_processes} {executable} {arguments}'
161
+ log = execute(f'Running job {name}...', command_line, tracer=self.tracer, return_='output')
162
+ with codecs.open(log_dir+'/.hpc.{name}.log'.format(**vars()), 'w', encoding='utf-8', errors='replace') as f: f.write(command_line+'\n'+log)
163
+
164
+ cpu_usage += time_module.time()/60./60.
165
+ self.cpu_usage += cpu_usage * actual_processes # approximation...
166
+
167
+ # return None - we do not return anything from this version of submit which imply returning None which in turn will be treated as job-id for already finished job
168
+
169
+
170
+ def complete(self, job_id):
171
+ ''' Return job completion status. Return True if job completed and False otherwise
172
+ '''
173
+ self.remove_completed_jobs()
174
+ return job_id not in self.jobs
175
+
176
+
177
+ def cancel_job(self, job):
178
+ job.cancel();
179
+ if job in self.jobs:
180
+ self.jobs.remove(job)
181
+
182
+
183
+ def __repr__(self):
184
+ return 'MultiCore_HPC_Driver<>'
.rosetta-ci/hpc_drivers/slurm.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # :noTabs=true:
3
+
4
+ import os, sys, time, collections, math
5
+ import stat as stat_module
6
+
7
+
8
+ try:
9
+ from .base import *
10
+
11
+ except ImportError: # workaround for B2 back-end's
12
+ import imp
13
+ imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/base.py') # A bit of Python magic here, what we trying to say is this: from base import *, but path to base is calculated from our source location # from base import HPC_Driver, execute, NT
14
+
15
+
16
+ _T_slurm_array_job_template_ = '''\
17
+ #!/bin/bash
18
+ #
19
+ #SBATCH --job-name={name}
20
+ #SBATCH --output={log_dir}/.hpc.%x.%a.output
21
+ #
22
+ #SBATCH --time={time}:00
23
+ #SBATCH --mem-per-cpu={memory}M
24
+ #SBATCH --chdir={working_dir}
25
+ #
26
+ #SBATCH --array=1-{jobs_to_queue}
27
+
28
+ srun {executable} {arguments}
29
+ '''
30
+
31
+ _T_slurm_mpi_job_template_ = '''\
32
+ #!/bin/bash
33
+ #
34
+ #SBATCH --job-name={name}
35
+ #SBATCH --output={log_dir}/.hpc.%x.output
36
+ #
37
+ #SBATCH --time={time}:00
38
+ #SBATCH --mem-per-cpu={memory}M
39
+ #SBATCH --chdir={working_dir}
40
+ #
41
+ #SBATCH --ntasks={ntasks}
42
+
43
+ mpirun {executable} {arguments}
44
+ '''
45
+
46
+ class Slurm_HPC_Driver(HPC_Driver):
47
+ def head_node_execute(self, message, command_line, *args, **kwargs):
48
+ head_node = self.config['slurm'].get('head_node')
49
+
50
+ command_line, host = (f"ssh {head_node} cd `pwd` '&& {command_line}'", head_node) if head_node else (command_line, 'localhost')
51
+ return execute(f'Executiong on {host}: {message}' if message else '', command_line, *args, **kwargs)
52
+
53
+
54
+ # NodeGroup = collections.namedtuple('NodeGroup', 'nodes cores')
55
+
56
+ # @property
57
+ # def mpi_topology(self):
58
+ # ''' return list of NodeGroup's
59
+ # '''
60
+ # pass
61
+
62
+
63
+ # @property
64
+ # def number_of_cpu_per_node(self): return int( self.config['condor']['mpi_cpu_per_node'] )
65
+
66
+ # @property
67
+ # def maximum_number_of_mpi_cpu(self):
68
+ # return self.number_of_cpu_per_node * int( self.config['condor']['mpi_maximum_number_of_nodes'] )
69
+
70
+
71
+ # def complete(self, condor_job_id):
72
+ # ''' Return job completion status. Note that single hpc_job may contatin inner list of individual HPC jobs, True should be return if they all run in to completion.
73
+ # '''
74
+
75
+ # execute('Releasing condor jobs...', 'condor_release $USER', return_='tuple')
76
+
77
+ # s = execute('', 'condor_q $USER | grep $USER | grep {}'.format(condor_job_id), return_='output', terminate_on_failure=False).replace(' ', '').replace('\n', '')
78
+ # if s: return False
79
+
80
+ # # #setDaemonStatusAndPing('[Job #%s] Running... %s condor job(s) in queue...' % (self.id, len(s.split('\n') ) ) )
81
+ # # n_jobs = len(s.split('\n'))
82
+ # # s, o = execute('', 'condor_userprio -all | grep $USER@', return_='tuple')
83
+ # # if s == 0:
84
+ # # jobs_running = o.split()
85
+ # # jobs_running = 'XX' if len(jobs_running) < 4 else jobs_running[4]
86
+ # # self.set_daemon_message("Waiting for condor to finish HPC jobs... [{} jobs in HPC-Queue, {} CPU's used]".format(n_jobs, jobs_running) )
87
+ # # print "{} condor jobs in queue... Sleeping 32s... \r".format(n_jobs),
88
+ # # sys.stdout.flush()
89
+ # # time.sleep(32)
90
+ # else:
91
+
92
+ # #self.tracer('Waiting for condor to finish the jobs... DONE')
93
+ # self.jobs.remove(condor_job_id)
94
+ # self.cpu_usage += self.get_condor_accumulated_usage()
95
+ # return True # jobs already finished, we return empty list to prevent double counting of cpu_usage
96
+
97
+
98
+ def complete(self, slurm_job_id):
99
+ ''' Return True if job with given id is complete
100
+ '''
101
+
102
+ s = self.head_node_execute('', f'squeue -j {slurm_job_id} --noheader', return_='output', terminate_on_failure=False, silent=True)
103
+ if s: return False
104
+ else:
105
+ #self.tracer('Waiting for condor to finish the jobs... DONE')
106
+ self.jobs.remove(slurm_job_id)
107
+ return True # jobs already finished, we return empty list to prevent double counting of cpu_usage
108
+
109
+
110
+ def cancel_job(self, slurm_job_id):
111
+ self.head_node_execute(f'Slurm_HPC_Driver.canceling job {slurm_job_id}...', f'scancel {slurm_job_id}', terminate_on_failure=False)
112
+
113
+
114
+ # def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
115
+ # print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
116
+ # return self.submit_serial_hpc_job(name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory, time, block, shell_wrapper)
117
+
118
+
119
+ def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
120
+
121
+ arguments = arguments.format(process='%a') # %a is SLURM array index
122
+ time = int( math.ceil(time*60) )
123
+
124
+ if shell_wrapper:
125
+ shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
126
+ with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
127
+ executable, arguments = shell_wrapper_sh, ''
128
+
129
+ slurm_file = working_dir + f'/.hpc.{name}.slurm'
130
+
131
+ with open(slurm_file, 'w') as f: f.write( _T_slurm_array_job_template_.format( **vars() ) )
132
+
133
+
134
+ slurm_job_id = self.head_node_execute('Submitting SLURM array job...', f'cd {self.working_dir} && sbatch {slurm_file}',
135
+ tracer=self.tracer, return_='output'
136
+ ).split()[-1] # expecting something like `Submitted batch job 6122` in output
137
+
138
+
139
+ self.jobs.append(slurm_job_id)
140
+
141
+ if block:
142
+ self.wait_until_complete( [slurm_job_id] )
143
+ return None
144
+
145
+ else: return slurm_job_id
146
+
147
+
148
+
149
+
150
+
151
+ def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, ntasks, memory=512, time=12, block=True, shell_wrapper=False):
152
+ ''' submit jobs as MPI job
153
+ '''
154
+ arguments = arguments.format(process='0')
155
+ time = int( math.ceil(time*60) )
156
+
157
+ if shell_wrapper:
158
+ shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
159
+ with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
160
+ executable, arguments = shell_wrapper_sh, ''
161
+
162
+ slurm_file = working_dir + f'/.hpc.{name}.slurm'
163
+
164
+ with open(slurm_file, 'w') as f: f.write( _T_slurm_mpi_job_template_.format( **vars() ) )
165
+
166
+ slurm_job_id = self.head_node_execute('Submitting SLURM mpi job...', f'cd {self.working_dir} && sbatch {slurm_file}',
167
+ tracer=self.tracer, return_='output'
168
+ ).split()[-1] # expecting something like `Submitted batch job 6122` in output
169
+
170
+ self.jobs.append(slurm_job_id)
171
+
172
+ if block:
173
+ self.wait_until_complete( [slurm_job_id] )
174
+ return None
175
+
176
+ else: return slurm_job_id
.rosetta-ci/test-sets.yaml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # map platform-string → platform definiton
2
+ platforms:
3
+ ubuntu-20.04.gcc:
4
+ os: ubuntu-20.04
5
+ compiler: gcc
6
+ python: '3.9'
7
+
8
+ ubuntu-20.04.clang:
9
+ os: ubuntu-20.04
10
+ compiler: clang
11
+ python: '3.9'
12
+
13
+
14
+ # map of test-set-name → tests
15
+ test-sets:
16
+ main:
17
+ - ubuntu-20.04.clang.rfd
18
+
19
+ python:
20
+ - ubuntu-20.04.gcc.self.python
21
+ - ubuntu-20.04.clang.self.python
22
+
23
+ self:
24
+ - ubuntu-20.04.gcc.self.state
25
+ - ubuntu-20.04.gcc.self.subtests
26
+ - ubuntu-20.04.gcc.self.release
27
+
28
+
29
+
30
+ # map of GitHub-label → [test-set]
31
+ github-label-test-sets:
32
+ 00 main: [main]
33
+ 10 self: [self]
34
+ 16 python: [python]
35
+
36
+
37
+ # map of submit-page-category → tests
38
+ # tests that does not get assigned will be automatically displayed in 'other' category
39
+ category-tests:
40
+ main:
41
+ - rfd
42
+
43
+ self:
44
+ - self.state
45
+ - self.subtests
46
+ - self.release
47
+ - self.python
48
+
49
+
50
+ # map branch → test-set to
51
+ # specify list of tests that should be applied by-default during testing of each new commits to specific branch
52
+ branch-test-sets:
53
+ main: [main]
54
+ benchmark: [main, python]
55
+
56
+
57
+ # map branch → test-sets for pull-request's
58
+ # specify which test-sets should be scheduled for PR's by-default (ie in addition to GH labels applied)
59
+ # use empty branch name to specify defult value for (ie any branch not explicitly listed)
60
+ pull-request-branch-test-sets:
61
+ # specific test sets for benchmark branch
62
+ benchmark: ['main', 'python']
63
+
64
+ # default, will apply to PR's to any other branch
65
+ '': ['main']
.rosetta-ci/tests/__init__.py ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # :noTabs=true:
4
+
5
+ # (c) Copyright Rosetta Commons Member Institutions.
6
+ # (c) This file is part of the Rosetta software suite and is made available under license.
7
+ # (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
8
+ # (c) For more information, see http://www.rosettacommons.org. Questions about this can be
9
+ # (c) addressed to University of Washington CoMotion, email: license@uw.edu.
10
+
11
+ ## @file tests/__init__.py
12
+ ## @brief Common constats and types for all test types
13
+ ## @author Sergey Lyskov
14
+
15
+ import os, time, sys, shutil, codecs, urllib.request, imp, subprocess, json, hashlib # urllib.error, urllib.parse,
16
+ import platform as platform_module
17
+ import types as types_module
18
+
19
+ # ⚔ do not change wording below, it have to stay in sync with upstream (up to benchmark-model).
20
+ # Copied from benchmark-model, standard state code's for tests results.
21
+
22
+ __all__ = ['execute',
23
+ '_S_Values_', '_S_draft_', '_S_queued_', '_S_running_', '_S_passed_', '_S_failed_', '_S_build_failed_', '_S_script_failed_',
24
+ '_StateKey_', '_ResultsKey_', '_LogKey_', '_DescriptionKey_', '_TestsKey_',
25
+ '_multi_step_config_', '_multi_step_error_', '_multi_step_result_',
26
+ 'to_bytes',
27
+ ]
28
+
29
+ _S_draft_ = 'draft'
30
+ _S_queued_ = 'queued'
31
+ _S_running_ = 'running'
32
+ _S_passed_ = 'passed'
33
+ _S_failed_ = 'failed'
34
+ _S_build_failed_ = 'build failed'
35
+ _S_script_failed_ = 'script failed'
36
+ _S_queued_for_comparison_ = 'queued for comparison'
37
+
38
+ _S_Values_ = [_S_draft_, _S_queued_, _S_running_, _S_passed_, _S_failed_, _S_build_failed_, _S_script_failed_, _S_queued_for_comparison_]
39
+
40
+ _IgnoreKey_ = 'ignore'
41
+ _StateKey_ = 'state'
42
+ _ResultsKey_ = 'results'
43
+ _LogKey_ = 'log'
44
+ _DescriptionKey_ = 'description'
45
+ _TestsKey_ = 'tests'
46
+ _SummaryKey_ = 'summary'
47
+ _FailedKey_ = 'failed'
48
+ _TotalKey_ = 'total'
49
+ _PlotsKey_ = 'plots'
50
+ _FailedTestsKey_ = 'failed_tests'
51
+ _HtmlKey_ = 'html'
52
+
53
+ # file names for multi-step test files
54
+ _multi_step_config_ = 'config.json'
55
+ _multi_step_error_ = 'error.json'
56
+ _multi_step_result_ = 'result.json'
57
+
58
+ PyRosetta_unix_memory_requirement_per_cpu = 6 # Memory per sub-process in Gb's
59
+ PyRosetta_unix_unit_test_memory_requirement_per_cpu = 3.0 # Memory per sub-process in Gb's for running PyRosetta unit tests
60
+
61
+ # Commands to run all the scripts needed for setting up Rosetta compiles. (Run from main/source directory)
62
+ PRE_COMPILE_SETUP_SCRIPTS = [ "./update_options.sh", "./update_submodules.sh", "./update_ResidueType_enum_files.sh", "python version.py" ]
63
+
64
+ DEFAULT_PYTHON_VERSION='3.9'
65
+
66
+ # Standard funtions and classes below ---------------------------------------------------------------------------------
67
+
68
+ class BenchmarkError(Exception):
69
+ def __init__(self, value): self.value = value
70
+ def __repr__(self): return self.value
71
+ def __str__(self): return self.value
72
+
73
+
74
+ class NT: # named tuple
75
+ def __init__(self, **entries): self.__dict__.update(entries)
76
+ def __repr__(self):
77
+ r = 'NT: |'
78
+ for i in dir(self):
79
+ print(i)
80
+ if not i.startswith('__') and i != '_as_dict' and not isinstance(getattr(self, i), types_module.MethodType): r += '%s --> %s, ' % (i, getattr(self, i))
81
+ return r[:-2]+'|'
82
+
83
+ @property
84
+ def _as_dict(self):
85
+ return { a: getattr(self, a) for a in dir(self) if not a.startswith('__') and a != '_as_dict' and not isinstance(getattr(self, a), types_module.MethodType)}
86
+
87
+
88
+ def Tracer(verbose=False):
89
+ return print if verbose else lambda x: None
90
+
91
+
92
+ def to_unicode(b):
93
+ ''' Conver bytes to string and handle the errors. If argument is already in string - do nothing
94
+ '''
95
+ #return b if type(b) == unicode else unicode(b, 'utf-8', errors='replace')
96
+ return b if type(b) == str else str(b, 'utf-8', errors='backslashreplace')
97
+
98
+
99
+ def to_bytes(u):
100
+ ''' Conver string to bytes and handle the errors. If argument is already of type bytes - do nothing
101
+ '''
102
+ return u if type(u) == bytes else u.encode('utf-8', errors='backslashreplace')
103
+
104
+
105
+ ''' Python-2 version
106
+ def execute(message, commandline, return_=False, until_successes=False, terminate_on_failure=True, add_message_and_command_line_to_output=False):
107
+ message, commandline = to_unicode(message), to_unicode(commandline)
108
+
109
+ TR = Tracer()
110
+ TR(message); TR(commandline)
111
+ while True:
112
+ (res, output) = commands.getstatusoutput(commandline)
113
+ # Subprocess results will always be a bytes-string.
114
+ # Probably ASCII, but may have some Unicode characters.
115
+ # A UTF-8 decode will probably get decent results 99% of the time
116
+ # and the replace option will gracefully handle the rest.
117
+ output = to_unicode(output)
118
+
119
+ TR(output)
120
+
121
+ if res and until_successes: pass # Thats right - redability COUNT!
122
+ else: break
123
+
124
+ print( "Error while executing %s: %s\n" % (message, output) )
125
+ print( "Sleeping 60s... then I will retry..." )
126
+ time.sleep(60)
127
+
128
+ if add_message_and_command_line_to_output: output = message + '\nCommand line: ' + commandline + '\n' + output
129
+
130
+ if return_ == 'tuple': return(res, output)
131
+
132
+ if res and terminate_on_failure:
133
+ TR("\nEncounter error while executing: " + commandline)
134
+ if return_==True: return res
135
+ else:
136
+ print("\nEncounter error while executing: " + commandline + '\n' + output)
137
+ raise BenchmarkError("\nEncounter error while executing: " + commandline + '\n' + output)
138
+
139
+ if return_ == 'output': return output
140
+ else: return res
141
+ '''
142
+
143
+ def execute_through_subprocess(command_line):
144
+ # exit_code, output = subprocess.getstatusoutput(command_line)
145
+
146
+ # p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
147
+ # output, errors = p.communicate()
148
+ # output = (output + errors).decode(encoding='utf-8', errors='backslashreplace')
149
+ # exit_code = p.returncode
150
+
151
+ # previous 'main' version based on subprocess module. Main issue that output of segfaults will not be captured since they generated by shell
152
+ p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
153
+ output, errors = p.communicate()
154
+ # output = output + errors # ← we redirected stderr into same pipe as stdcout so errors is None, - no need to concatenate
155
+ output = output.decode(encoding='utf-8', errors='backslashreplace')
156
+ exit_code = p.returncode
157
+
158
+ return exit_code, output
159
+
160
+
161
+ def execute_through_pexpect(command_line):
162
+ import pexpect
163
+
164
+ child = pexpect.spawn('/bin/bash', ['-c', command_line])
165
+ child.expect(pexpect.EOF)
166
+ output = child.before.decode(encoding='utf-8', errors='backslashreplace')
167
+ child.close()
168
+ exit_code = child.signalstatus or child.exitstatus
169
+
170
+ return exit_code, output
171
+
172
+
173
+ def execute_through_pty(command_line):
174
+ import pty, select
175
+
176
+ if sys.platform == "darwin":
177
+
178
+ master, slave = pty.openpty()
179
+ p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
180
+ stderr=subprocess.STDOUT, close_fds=True)
181
+
182
+ buffer = []
183
+ while True:
184
+ try:
185
+ if select.select([master], [], [], 0.2)[0]: # has something to read
186
+ data = os.read(master, 1 << 22)
187
+ if data: buffer.append(data)
188
+
189
+ elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0] ): break # process is finished and output buffer if fully read
190
+
191
+ except OSError: break # OSError will be raised when child process close PTY descriptior
192
+
193
+ output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
194
+
195
+ os.close(master)
196
+ os.close(slave)
197
+
198
+ p.wait()
199
+ exit_code = p.returncode
200
+
201
+ '''
202
+ buffer = []
203
+ while True:
204
+ if select.select([master], [], [], 0.2)[0]: # has something to read
205
+ data = os.read(master, 1 << 22)
206
+ if data: buffer.append(data)
207
+ # else: break # # EOF - well, technically process _should_ be finished here...
208
+
209
+ # elif time.sleep(1) or (p.poll() is not None): # process is finished (sleep here is intentional to trigger race condition, see solution for this on the next few lines)
210
+ # assert not select.select([master], [], [], 0.2)[0] # should be nothing left to read...
211
+ # break
212
+
213
+ elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0] ): break # process is finished and output buffer if fully read
214
+
215
+ assert not select.select([master], [], [], 0.2)[0] # should be nothing left to read...
216
+
217
+ os.close(slave)
218
+ os.close(master)
219
+
220
+ output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
221
+ exit_code = p.returncode
222
+ '''
223
+
224
+ else:
225
+
226
+ master, slave = pty.openpty()
227
+ p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
228
+ stderr=subprocess.STDOUT, close_fds=True)
229
+
230
+ os.close(slave)
231
+
232
+ buffer = []
233
+ while True:
234
+ try:
235
+ data = os.read(master, 1 << 22)
236
+ if data: buffer.append(data)
237
+ except OSError: break # OSError will be raised when child process close PTY descriptior
238
+
239
+ output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
240
+
241
+ os.close(master)
242
+
243
+ p.wait()
244
+ exit_code = p.returncode
245
+
246
+ return exit_code, output
247
+
248
+
249
+
250
+ def execute(message, command_line, return_='status', until_successes=False, terminate_on_failure=True, silent=False, silence_output=False, silence_output_on_errors=False, add_message_and_command_line_to_output=False):
251
+ if not silent: print(message); print(command_line); sys.stdout.flush();
252
+ while True:
253
+
254
+ #exit_code, output = execute_through_subprocess(command_line)
255
+ #exit_code, output = execute_through_pexpect(command_line)
256
+ exit_code, output = execute_through_pty(command_line)
257
+
258
+ if (exit_code and not silence_output_on_errors) or not (silent or silence_output): print(output); sys.stdout.flush();
259
+
260
+ if exit_code and until_successes: pass # Thats right - redability COUNT!
261
+ else: break
262
+
263
+ print( "Error while executing {}: {}\n".format(message, output) )
264
+ print("Sleeping 60s... then I will retry...")
265
+ sys.stdout.flush();
266
+ time.sleep(60)
267
+
268
+ if add_message_and_command_line_to_output: output = message + '\nCommand line: ' + command_line + '\n' + output
269
+
270
+ if return_ == 'tuple' or return_ == tuple: return(exit_code, output)
271
+
272
+ if exit_code and terminate_on_failure:
273
+ print("\nEncounter error while executing: " + command_line)
274
+ if return_==True: return True
275
+ else:
276
+ print('\nEncounter error while executing: ' + command_line + '\n' + output);
277
+ raise BenchmarkError('\nEncounter error while executing: ' + command_line + '\n' + output)
278
+
279
+ if return_ == 'output': return output
280
+ else: return exit_code
281
+
282
+
283
+ def parallel_execute(name, jobs, rosetta_dir, working_dir, cpu_count, time=16):
284
+ ''' Execute command line in parallel on local host
285
+ time specifies the upper limit for cpu-usage runtime (in minutes) for any one process in the parallel execution.
286
+
287
+ jobs should be dict with following structure:
288
+ {
289
+ 'job-string-id-1’: command_line-1,
290
+ 'job-string-id-2’: command_line-2,
291
+ ...
292
+ }
293
+
294
+ return: dict with jobs-id's as keys and value as dict with 'output' and 'result' keys:
295
+ {
296
+ "job-string-id-1": {
297
+ "output": "stdout + stdderr output of command_line-1",
298
+ "result": <integer exit code for command_line-1>
299
+ },
300
+ "c2": {
301
+ "output": "stdout + stdderr output of command_line-2",
302
+ "result": <integer exit code for command_line-2>
303
+ },
304
+ ...
305
+ }
306
+ '''
307
+ job_file_name = working_dir + '/' + name
308
+ with open(job_file_name + '.json', 'w') as f: json.dump(jobs, f, sort_keys=True, indent=2) # JSON handles unicode internally
309
+ if time is not None:
310
+ allowed_time = int(time*60)
311
+ ulimit_command = f'ulimit -t {allowed_time} && '
312
+ else:
313
+ ulimit_command = ''
314
+ command = f'cd {working_dir} && ' + ulimit_command + f'{rosetta_dir}/tests/benchmark/util/parallel.py -j{cpu_count} {job_file_name}.json'
315
+ execute("Running {} in parallel with {} CPU's...".format(name, cpu_count), command )
316
+
317
+ with open(job_file_name+'.results.json') as f: return json.load(f)
318
+
319
+
320
+ def calculate_unique_prefix_path(platform, config):
321
+ ''' calculate path for prefix location that is unique for this machine and OS
322
+ '''
323
+ hostname = os.uname()[1]
324
+ return config['prefix'] + '/' + hostname + '/' + platform['os']
325
+
326
+
327
+ def get_python_include_and_lib(python):
328
+ ''' calculate python include dir and lib dir from given python executable path
329
+ '''
330
+ #python = os.path.realpath(python)
331
+ python_bin_dir = python.rpartition('/')[0]
332
+ python_config = f'{python} {python}-config' if python.endswith('2.7') else f'{python}-config'
333
+
334
+ #if not os.path.isfile(python_config): python_config = python_bin_dir + '/python-config'
335
+
336
+ info = execute('Getting python configuration info...', f'unset __PYVENV_LAUNCHER__ && cd {python_bin_dir} && PATH=.:$PATH && {python_config} --prefix --includes', return_='output').replace('\r', '').split('\n') # Python-3 only: --abiflags
337
+ python_prefix = info[0]
338
+ python_include_dir = info[1].split()[0][len('-I'):]
339
+ python_lib_dir = python_prefix + '/lib'
340
+ #python_abi_suffix = info[2]
341
+ #print(python_include_dir, python_lib_dir)
342
+
343
+ return NT(python_include_dir=python_include_dir, python_lib_dir=python_lib_dir)
344
+
345
+
346
+ def local_open_ssl_install(prefix, build_prefix, jobs):
347
+ ''' install OpenSSL at given prefix, return url of source archive
348
+ '''
349
+ #with tempfile.TemporaryDirectory('open_ssl_build', dir=prefix) as build_prefix:
350
+
351
+ url = 'https://www.openssl.org/source/openssl-1.1.1b.tar.gz'
352
+ #url = 'https://www.openssl.org/source/openssl-3.0.0.tar.gz'
353
+
354
+
355
+ archive = build_prefix + '/' + url.split('/')[-1]
356
+ build_dir = archive.rpartition('.tar.gz')[0]
357
+ if os.path.isdir(build_dir): shutil.rmtree(build_dir)
358
+
359
+ with open(archive, 'wb') as f:
360
+ response = urllib.request.urlopen(url)
361
+ f.write( response.read() )
362
+
363
+ execute('Unpacking {}'.format(archive), 'cd {build_prefix} && tar -xvzf {archive}'.format(**vars()) )
364
+
365
+ execute('Configuring...', f'cd {build_dir} && ./config --prefix={prefix}')
366
+ execute('Building...', f'cd {build_dir} && make -j{jobs}')
367
+ execute('Installing...', f'cd {build_dir} && make -j{jobs} install')
368
+
369
+ return url
370
+
371
+
372
+ def remove_pip_and_easy_install(prefix_root_path):
373
+ ''' remove `pip` and `easy_install` executable from given Python / virtual-environments install
374
+ '''
375
+ for f in os.listdir(prefix_root_path + '/bin'): # removing all pip's and easy_install's to make sure that environment is immutable
376
+ for p in ['pip', 'easy_install']:
377
+ if f.startswith(p): os.remove(prefix_root_path + '/bin/' + f)
378
+
379
+
380
+
381
+ def local_python_install(platform, config):
382
+ ''' Perform local install of given Python version and return path-to-python-interpreter, python_include_dir, python_lib_dir
383
+ If previous install is detected skip installiation.
384
+ Provided Python install will _persistent_ and _immutable_
385
+ '''
386
+ jobs = config['cpu_count']
387
+ compiler, cpp_compiler = ('clang', 'clang++') if platform['os'] == 'mac' else ('gcc', 'g++') # disregarding platform compiler setting and instead use default compiler for platform
388
+
389
+ python_version = platform.get('python', DEFAULT_PYTHON_VERSION)
390
+
391
+ if python_version.endswith('.s'):
392
+ assert python_version == f'{sys.version_info.major}.{sys.version_info.minor}.s'
393
+ #root = executable.rpartition('/bin/python')[0]
394
+ h = hashlib.md5(); h.update( (sys.executable + sys.version).encode('utf-8', errors='backslashreplace') ); hash = h.hexdigest()
395
+ return NT(
396
+ python = sys.executable,
397
+ root = None,
398
+ python_include_dir = None,
399
+ python_lib_dir = None,
400
+ version = python_version,
401
+ url = None,
402
+ platform = platform,
403
+ config = config,
404
+ hash = hash,
405
+ )
406
+
407
+ # deprecated, no longer needed
408
+ # python_version = {'python2' : '2.7',
409
+ # 'python2.7' : '2.7',
410
+ # 'python3' : '3.5',
411
+ # }.get(python_version, python_version)
412
+
413
+ # for security reasons we only allow installs for version listed here with hand-coded URL's
414
+ python_sources = {
415
+ '2.7' : 'https://www.python.org/ftp/python/2.7.18/Python-2.7.18.tgz',
416
+
417
+ '3.5' : 'https://www.python.org/ftp/python/3.5.9/Python-3.5.9.tgz',
418
+ '3.6' : 'https://www.python.org/ftp/python/3.6.15/Python-3.6.15.tgz',
419
+ '3.7' : 'https://www.python.org/ftp/python/3.7.14/Python-3.7.14.tgz',
420
+ '3.8' : 'https://www.python.org/ftp/python/3.8.14/Python-3.8.14.tgz',
421
+ '3.9' : 'https://www.python.org/ftp/python/3.9.14/Python-3.9.14.tgz',
422
+ '3.10' : 'https://www.python.org/ftp/python/3.10.10/Python-3.10.10.tgz',
423
+ '3.11' : 'https://www.python.org/ftp/python/3.11.2/Python-3.11.2.tgz',
424
+ }
425
+
426
+ # map of env -> ('shell-code-before ./configure', 'extra-arguments-for-configure')
427
+ extras = {
428
+ #('mac',) : ('__PYVENV_LAUNCHER__="" MACOSX_DEPLOYMENT_TARGET={}'.format(platform_module.mac_ver()[0]), ''), # __PYVENV_LAUNCHER__ now used by-default for all platform installs
429
+ ('mac',) : ('MACOSX_DEPLOYMENT_TARGET={}'.format(platform_module.mac_ver()[0]), ''),
430
+ ('linux', '2.7') : ('', '--enable-unicode=ucs4'),
431
+ ('ubuntu', '2.7') : ('', '--enable-unicode=ucs4'),
432
+ }
433
+
434
+ #packages = '' if (python_version[0] == '2' or python_version == '3.5' ) and platform['os'] == 'mac' else 'pip setuptools wheel' # 2.7 is now deprecated on Mac so some packages could not be installed
435
+ packages = 'setuptools'
436
+
437
+ url = python_sources[python_version]
438
+
439
+ extra = extras.get( (platform['os'],) , ('', '') )
440
+ extra = extras.get( (platform['os'], python_version) , extra)
441
+
442
+ extra = ('unset __PYVENV_LAUNCHER__ && ' + extra[0], extra[1])
443
+
444
+ options = '--with-ensurepip' #'--without-ensurepip'
445
+ signature = f'v1.5.1 url: {url}\noptions: {options}\ncompiler: {compiler}\nextra: {extra}\npackages: {packages}\n'
446
+
447
+ h = hashlib.md5(); h.update( signature.encode('utf-8', errors='backslashreplace') ); hash = h.hexdigest()
448
+
449
+ root = calculate_unique_prefix_path(platform, config) + '/python-' + python_version + '.' + compiler + '/' + hash
450
+
451
+ signature_file_name = root + '/.signature'
452
+
453
+ #activate = root + '/bin/activate'
454
+ executable = root + '/bin/python' + python_version
455
+
456
+ # if os.path.isfile(executable) and (not execute('Getting python configuration info...', '{executable}-config --prefix --includes'.format(**vars()), terminate_on_failure=False) ):
457
+ # print('found executable!')
458
+ # _, executable_version = execute('Checking Python interpreter version...', '{executable} --version'.format(**vars()), return_='tuple')
459
+ # executable_version = executable_version.split()[-1]
460
+ # else: executable_version = ''
461
+ # print('executable_version: {}'.format(executable_version))
462
+ #if executable_version != url.rpartition('Python-')[2][:-len('.tgz')]:
463
+
464
+ if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature:
465
+ #print('Install for Python-{} is detected, skipping installation procedure...'.format(python_version))
466
+ pass
467
+
468
+ else:
469
+ print( 'Installing Python-{python_version}, using {url} with extra:{extra}...'.format( **vars() ) )
470
+
471
+ if os.path.isdir(root): shutil.rmtree(root)
472
+
473
+ build_prefix = os.path.abspath(root + '/../build-python-{}'.format(python_version) )
474
+
475
+ if not os.path.isdir(root): os.makedirs(root)
476
+ if not os.path.isdir(build_prefix): os.makedirs(build_prefix)
477
+
478
+ platform_is_mac = True if platform['os'] in ['mac', 'm1'] else False
479
+ platform_is_linux = not platform_is_mac
480
+
481
+ #if False and platform['os'] == 'mac' and platform_module.machine() == 'arm64' and tuple( map(int, python_version.split('.') ) ) >= (3, 9):
482
+ if ( platform['os'] == 'mac' and python_version == '3.6' ) \
483
+ or ( platform_is_linux and python_version in ['3.10', '3.11'] ):
484
+ open_ssl_url = local_open_ssl_install(root, build_prefix, jobs)
485
+ options += f' --with-openssl={root} --with-openssl-rpath=auto'
486
+ #signature += 'OpenSSL install: ' + open_ssl_url + '\n'
487
+
488
+ archive = build_prefix + '/' + url.split('/')[-1]
489
+ build_dir = archive.rpartition('.tgz')[0]
490
+ if os.path.isdir(build_dir): shutil.rmtree(build_dir)
491
+
492
+ with open(archive, 'wb') as f:
493
+ #response = urllib2.urlopen(url)
494
+ response = urllib.request.urlopen(url)
495
+ f.write( response.read() )
496
+
497
+ #execute('Execution environment:', 'env'.format(**vars()) )
498
+
499
+ execute('Unpacking {}'.format(archive), 'cd {build_prefix} && tar -xvzf {archive}'.format(**vars()) )
500
+
501
+ #execute('Building and installing...', 'cd {} && CC={compiler} CXX={cpp_compiler} {extra[0]} ./configure {extra[1]} --prefix={root} && {extra[0]} make -j{jobs} && {extra[0]} make install'.format(build_dir, **locals()) )
502
+ execute('Configuring...', 'cd {} && CC={compiler} CXX={cpp_compiler} {extra[0]} ./configure {options} {extra[1]} --prefix={root}'.format(build_dir, **locals()) )
503
+ execute('Building...', 'cd {} && {extra[0]} make -j{jobs}'.format(build_dir, **locals()) )
504
+ execute('Installing...', 'cd {} && {extra[0]} make -j{jobs} install'.format(build_dir, **locals()) )
505
+
506
+ shutil.rmtree(build_prefix)
507
+
508
+ #execute('Updating setuptools...', f'cd {root} && {root}/bin/pip{python_version} install --upgrade setuptools wheel' )
509
+
510
+ # if 'certifi' not in packages:
511
+ # packages += ' certifi'
512
+
513
+ if packages: execute( f'Installing packages {packages}...', f'cd {root} && unset __PYVENV_LAUNCHER__ && {root}/bin/pip{python_version} install --upgrade {packages}' )
514
+ #if packages: execute( f'Installing packages {packages}...', f'cd {root} && unset __PYVENV_LAUNCHER__ && {executable} -m pip install --upgrade {packages}' )
515
+
516
+ remove_pip_and_easy_install(root) # removing all pip's and easy_install's to make sure that environment is immutable
517
+
518
+ with open(signature_file_name, 'w') as f: f.write(signature)
519
+
520
+ print( 'Installing Python-{python_version}, using {url} with extra:{extra}... Done.'.format( **vars() ) )
521
+
522
+ il = get_python_include_and_lib(executable)
523
+
524
+ return NT(
525
+ python = executable,
526
+ root = root,
527
+ python_include_dir = il.python_include_dir,
528
+ python_lib_dir = il.python_lib_dir,
529
+ version = python_version,
530
+ url = url,
531
+ platform = platform,
532
+ config = config,
533
+ hash = hash,
534
+ )
535
+
536
+
537
+
538
+ def setup_python_virtual_environment(working_dir, python_environment, packages=''):
539
+ ''' Deploy Python virtual environment at working_dir
540
+ '''
541
+
542
+ python = python_environment.python
543
+
544
+ execute('Setting up Python virtual environment...', 'unset __PYVENV_LAUNCHER__ && {python} -m venv --clear {working_dir}'.format(**vars()) )
545
+
546
+ activate = f'unset __PYVENV_LAUNCHER__ && . {working_dir}/bin/activate'
547
+
548
+ bin=working_dir+'/bin'
549
+
550
+ if packages: execute('Installing packages: {}...'.format(packages), 'unset __PYVENV_LAUNCHER__ && {bin}/python {bin}/pip install --upgrade pip setuptools && {bin}/python {bin}/pip install --progress-bar off {packages}'.format(**vars()) )
551
+ #if packages: execute('Installing packages: {}...'.format(packages), '{bin}/pip{python_environment.version} install {packages}'.format(**vars()) )
552
+
553
+ return NT(activate = activate, python = bin + '/python', root = working_dir, bin = bin)
554
+
555
+
556
+
557
+ def setup_persistent_python_virtual_environment(python_environment, packages):
558
+ ''' Setup _persistent_ and _immutable_ Python virtual environment which will be saved between test runs
559
+ '''
560
+
561
+ if python_environment.version.startswith('2.'):
562
+ assert not packages, f'ERROR: setup_persistent_python_virtual_environment does not support Python-2.* with non-empty package list!'
563
+ return NT(activate = ':', python = python_environment.python, root = python_environment.root, bin = python_environment.root + '/bin')
564
+
565
+ else:
566
+ #if 'certifi' not in packages: packages += ' certifi'
567
+
568
+ h = hashlib.md5()
569
+ h.update(f'v1.0.0 platform: {python_environment.platform} python_source_url: {python_environment.url} python-hash: {python_environment.hash} packages: {packages}'.encode('utf-8', errors='backslashreplace') )
570
+ hash = h.hexdigest()
571
+
572
+ prefix = calculate_unique_prefix_path(python_environment.platform, python_environment.config)
573
+
574
+ root = os.path.abspath( prefix + '/python_virtual_environments/' + '/python-' + python_environment.version + '/' + hash )
575
+ signature_file_name = root + '/.signature'
576
+ signature = f'setup_persistent_python_virtual_environment v1.0.0\npython: {python_environment.hash}\npackages: {packages}\n'
577
+
578
+ activate = f'unset __PYVENV_LAUNCHER__ && . {root}/bin/activate'
579
+ bin = f'{root}/bin'
580
+
581
+ if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature: pass
582
+ else:
583
+ if os.path.isdir(root): shutil.rmtree(root)
584
+ setup_python_virtual_environment(root, python_environment, packages=packages)
585
+ remove_pip_and_easy_install(root) # removing all pip's and easy_install's to make sure that environment is immutable
586
+ with open(signature_file_name, 'w') as f: f.write(signature)
587
+
588
+ return NT(activate = activate, python = bin + '/python', root = root, bin = bin, hash = hash)
589
+
590
+
591
+
592
+ def _get_path_to_conda_root(platform, config):
593
+ ''' Perform local (prefix) install of miniconda and return NT(activate, conda_root_dir, conda)
594
+ this function is for inner use only, - to setup custom conda environment inside your test use `setup_conda_virtual_environment` defined below
595
+ '''
596
+ miniconda_sources = {
597
+ 'mac' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh',
598
+ 'linux' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh',
599
+ 'aarch64': 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh',
600
+ 'ubuntu' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh',
601
+ 'm1' : 'https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.1-MacOSX-arm64.sh',
602
+ }
603
+
604
+ conda_sources = {
605
+ 'mac' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-MacOSX-x86_64.sh',
606
+ 'linux' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-Linux-x86_64.sh',
607
+ 'ubuntu' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-Linux-x86_64.sh',
608
+ }
609
+
610
+ #platform_os = 'm1' if platform_module.machine() == 'arm64' else platform['os']
611
+ #url = miniconda_sources[ platform_os ]
612
+
613
+ platform_os = platform['os']
614
+ for o in 'alpine centos ubuntu'.split():
615
+ if platform_os.startswith(o): platform_os = 'linux'
616
+
617
+ url = miniconda_sources[platform_os]
618
+
619
+ version = '1'
620
+ channels = '' # conda-forge
621
+
622
+ #packages = ['conda-build gcc libgcc', 'libgcc=5.2.0'] # libgcc installs is workaround for "Anaconda libstdc++.so.6: version `GLIBCXX_3.4.20' not found", see: https://stackoverflow.com/questions/48453497/anaconda-libstdc-so-6-version-glibcxx-3-4-20-not-found
623
+ #packages = ['conda-build gcc'] # libgcc installs is workaround for "Anaconda libstdc++.so.6: version `GLIBCXX_3.4.20' not found", see: https://stackoverflow.com/questions/48453497/anaconda-libstdc-so-6-version-glibcxx-3-4-20-not-found
624
+ packages = ['conda-build anaconda-client conda-verify',]
625
+
626
+ signature = f'url: {url}\nversion: {version}\channels: {channels}\npackages: {packages}\n'
627
+
628
+ root = calculate_unique_prefix_path(platform, config) + '/conda'
629
+
630
+ signature_file_name = root + '/.signature'
631
+
632
+ # presense of __PYVENV_LAUNCHER__,PYTHONHOME, PYTHONPATH sometimes confuse Python so we have to unset them
633
+ unset = 'unset __PYVENV_LAUNCHER__ && unset PYTHONHOME && unset PYTHONPATH'
634
+ activate = unset + ' && . ' + root + '/bin/activate'
635
+
636
+ executable = root + '/bin/conda'
637
+
638
+
639
+ if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature:
640
+ print( f'Install for MiniConda is detected, skipping installation procedure...' )
641
+
642
+ else:
643
+ print( f'Installing MiniConda, using {url}...' )
644
+
645
+ if os.path.isdir(root): shutil.rmtree(root)
646
+
647
+ build_prefix = os.path.abspath(root + f'/../build-conda' )
648
+
649
+ #if not os.path.isdir(root): os.makedirs(root)
650
+ if not os.path.isdir(build_prefix): os.makedirs(build_prefix)
651
+
652
+ archive = build_prefix + '/' + url.split('/')[-1]
653
+
654
+ with open(archive, 'wb') as f:
655
+ response = urllib.request.urlopen(url)
656
+ f.write( response.read() )
657
+
658
+ execute('Installing conda...', f'cd {build_prefix} && {unset} && bash {archive} -b -p {root}' )
659
+
660
+ # conda update --yes --quiet -n base -c defaults conda
661
+
662
+ if channels: execute(f'Adding extra channles {channels}...', f'cd {build_prefix} && {activate} && conda config --add channels {channels}' )
663
+
664
+ for p in packages: execute(f'Installing conda packages: {p}...', f'cd {build_prefix} && {activate} && conda install --quiet --yes {p}' )
665
+
666
+ shutil.rmtree(build_prefix)
667
+
668
+ with open(signature_file_name, 'w') as f: f.write(signature)
669
+
670
+ print( f'Installing MiniConda, using {url}... Done.' )
671
+
672
+ execute(f'Updating conda base...', f'{activate} && conda update --all --yes' )
673
+ return NT(conda=executable, root=root, activate=activate, url=url)
674
+
675
+
676
+
677
+ def setup_conda_virtual_environment(working_dir, platform, config, packages=''):
678
+ ''' Deploy Conda virtual environment at working_dir
679
+ '''
680
+ conda_root_env = _get_path_to_conda_root(platform, config)
681
+ activate = conda_root_env.activate
682
+
683
+ python_version = platform.get('python', DEFAULT_PYTHON_VERSION)
684
+
685
+ prefix = os.path.abspath( working_dir + '/.conda-python-' + python_version )
686
+
687
+ command_line = f'conda create --quiet --yes --prefix {prefix} python={python_version}'
688
+
689
+ execute( f'Setting up Conda for Python-{python_version} virtual environment...', f'cd {working_dir} && {activate} && ( {command_line} || ( conda clean --yes && {command_line} ) )' )
690
+
691
+ activate = f'{activate} && conda activate {prefix}'
692
+
693
+ if packages: execute( f'Setting up extra packages {packages}...', f'cd {working_dir} && {activate} && conda install --quiet --yes {packages}' )
694
+
695
+ python = prefix + '/bin/python' + python_version
696
+
697
+ il = get_python_include_and_lib(python)
698
+
699
+ return NT(
700
+ activate = activate,
701
+ root = prefix,
702
+ python = python,
703
+ python_include_dir = il.python_include_dir,
704
+ python_lib_dir = il.python_lib_dir,
705
+ version = python_version,
706
+ activate_base = conda_root_env.activate,
707
+ url = prefix, # conda_root_env.url,
708
+ platform=platform,
709
+ config=config,
710
+ )
711
+
712
+
713
+
714
+ class FileLock():
715
+ ''' Implementation of file-lock object that could be use with Python `with` statement
716
+ '''
717
+
718
+ def __init__(self, file_name):
719
+ self.locked = False
720
+ self.file_name = file_name
721
+
722
+
723
+ def __enter__(self):
724
+ if not self.locked: self.acquire()
725
+ return self
726
+
727
+
728
+ def __exit__(self, exc_type, exc_value, traceback):
729
+ if self.locked: self.release()
730
+
731
+
732
+ def __del__(self):
733
+ self.release()
734
+
735
+
736
+ def acquire(self):
737
+ while True:
738
+ try:
739
+ os.close( os.open(self.file_name, os.O_CREAT | os.O_EXCL, mode=0o600) )
740
+ self.locked = True
741
+ break
742
+
743
+ except FileExistsError as e:
744
+ time.sleep(60)
745
+
746
+
747
+ def release(self):
748
+ if self.locked:
749
+ os.remove(self.file_name)
750
+ self.locked = False
751
+
752
+
753
+
754
+ def convert_submodule_urls_from_ssh_to_https(repository_root):
755
+ ''' switching submodules URL to HTTPS so we can clone without SSH key
756
+ '''
757
+ with open(f'{repository_root}/.gitmodules') as f: m = f.read()
758
+ with open(f'{repository_root}/.gitmodules', 'w') as f:
759
+ f.write(
760
+ m
761
+ .replace('url = git@github.com:', 'url = https://github.com/')
762
+ .replace('url = ../../../', 'url = https://github.com/RosettaCommons/')
763
+ .replace('url = ../../', 'url = https://github.com/RosettaCommons/')
764
+ .replace('url = ../', 'url = https://github.com/RosettaCommons/')
765
+ )
.rosetta-ci/tests/rfd.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # :noTabs=true:
4
+
5
+ # (c) Copyright Rosetta Commons Member Institutions.
6
+ # (c) This file is part of the Rosetta software suite and is made available under license.
7
+ # (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
8
+ # (c) For more information, see http://www.rosettacommons.org. Questions about this can be
9
+ # (c) addressed to University of Washington CoMotion, email: license@uw.edu.
10
+
11
+ ## @file rfd.py
12
+ ## @brief main test files for RFdiffusion
13
+ ## @author Sergey Lyskov
14
+
15
+
16
+ import imp
17
+ imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/__init__.py') # A bit of Python magic here, what we trying to say is this: from __init__ import *, but init is calculated from file location
18
+
19
+ _api_version_ = '1.0'
20
+
21
+ import os, tempfile, shutil
22
+ import urllib.request
23
+
24
+
25
+ _models_urls_ = '''
26
+ http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt
27
+ http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt
28
+ http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt
29
+ http://files.ipd.uw.edu/pub/RFdiffusion/74f51cfb8b440f50d70878e05361d8f0/InpaintSeq_ckpt.pt
30
+ http://files.ipd.uw.edu/pub/RFdiffusion/76d00716416567174cdb7ca96e208296/InpaintSeq_Fold_ckpt.pt
31
+ http://files.ipd.uw.edu/pub/RFdiffusion/5532d2e1f3a4738decd58b19d633b3c3/ActiveSite_ckpt.pt
32
+ http://files.ipd.uw.edu/pub/RFdiffusion/12fc204edeae5b57713c5ad7dcb97d39/Base_epoch8_ckpt.pt
33
+ http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt
34
+ http://files.ipd.uw.edu/pub/RFdiffusion/1befcb9b28e2f778f53d47f18b7597fa/RF_structure_prediction_weights.pt
35
+ '''.split()
36
+
37
+
38
+ def run_main_test_suite(repository_root, working_dir, platform, config, debug):
39
+ full_log = ''
40
+
41
+ python_environment = local_python_install(platform, config)
42
+
43
+ models_dir = repository_root + '/models'
44
+ if not os.path.isdir(models_dir): os.makedirs(models_dir)
45
+
46
+ for url in _models_urls_:
47
+ file_name = models_dir + '/' + url.split('/')[-1]
48
+ tmp_file_name = file_name + '.tmp'
49
+ if not os.path.isfile(file_name):
50
+ print(f'downloading {url}...')
51
+ full_log += f'downloading {url}...\n'
52
+ urllib.request.urlretrieve(url, tmp_file_name)
53
+ os.rename(tmp_file_name, file_name)
54
+
55
+ execute('unpacking ppi scaffolds...', f'cd {repository_root} && tar -xvf examples/ppi_scaffolds_subset.tar.gz -C examples')
56
+
57
+ with tempfile.TemporaryDirectory(dir=working_dir) as tmpdirname:
58
+ # tmpdirname = working_dir+'/.ve'
59
+ # if True:
60
+
61
+ #ve = setup_persistent_python_virtual_environment(python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl')
62
+ #ve = setup_python_virtual_environment(working_dir+'/.ve', python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl e3nn icecream pyrsistent wandb pynvml decorator jedi hydra-core')
63
+ ve = setup_python_virtual_environment(tmpdirname, python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl e3nn icecream pyrsistent wandb pynvml decorator jedi hydra-core')
64
+
65
+ execute('Installing local se3-transformer package...', f'cd {repository_root}/env/SE3Transformer && {ve.bin}/pip3 install --editable .')
66
+ execute('Installing RFdiffusion package...', f'cd {repository_root} && {ve.bin}/pip3 install --editable .')
67
+
68
+ #res, output = execute('running unit tests...', f'{ve.activate} && cd {repository_root} && python -m unittest', return_='tuple', add_message_and_command_line_to_output=True)
69
+ #res, output = execute('running unit tests...', f'cd {repository_root} && {ve.bin}/pytest', return_='tuple')
70
+
71
+
72
+ results_file = f'{repository_root}/tests/.results.json'
73
+ if os.path.isfile(results_file): os.remove(results_file)
74
+
75
+ res, output = execute('running RFdiffusion tests...', f'{ve.activate} && cd {repository_root}/tests && python test_diffusion.py', return_='tuple', add_message_and_command_line_to_output=True)
76
+
77
+ if os.path.isfile(results_file):
78
+ with open(results_file) as f: sub_tests_reults = json.load(f)
79
+
80
+ state = _S_passed_
81
+ for r in sub_tests_reults.values():
82
+ if r[_StateKey_] == _S_failed_:
83
+ state = _S_failed_
84
+ break
85
+
86
+ else:
87
+ sub_tests_reults = {}
88
+ output += '\n\nEmpty sub-test results, marking test as `failed`...'
89
+ state = _S_failed_
90
+
91
+ shutil.move(f'{repository_root}/tests/outputs', f'{working_dir}/outputs')
92
+
93
+ for d in os.listdir(f'{repository_root}/tests'):
94
+ p = f'{repository_root}/tests/{d}'
95
+ if d.startswith('tests_') and os.path.isdir(p): shutil.rmtree(p)
96
+
97
+ results = {
98
+ _StateKey_ : state,
99
+ _LogKey_ : full_log + '\n' + output,
100
+ _ResultsKey_ : {
101
+ _TestsKey_ : sub_tests_reults,
102
+ },
103
+ }
104
+
105
+ return results
106
+
107
+
108
+
109
+ def run(test, repository_root, working_dir, platform, config, hpc_driver=None, verbose=False, debug=False):
110
+ if test == '': return run_main_test_suite(repository_root=repository_root, working_dir=working_dir, platform=platform, config=config, debug=debug)
111
+ else: raise BenchmarkError('Unknow scripts test: {}!'.format(test))
.rosetta-ci/tests/self.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # self test suite
2
+ These tests are design to help debug interface between testing server and Rosetta testing scripts
3
+
4
+ -----
5
+ ### python
6
+ Test Python platform support and functionality of local and persistent Python virtual environments
.rosetta-ci/tests/self.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # :noTabs=true:
4
+
5
+ # (c) Copyright Rosetta Commons Member Institutions.
6
+ # (c) This file is part of the Rosetta software suite and is made available under license.
7
+ # (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
8
+ # (c) For more information, see http://www.rosettacommons.org. Questions about this can be
9
+ # (c) addressed to University of Washington CoMotion, email: license@uw.edu.
10
+
11
+ ## @file dummy.py
12
+ ## @brief self-test and debug-aids tests
13
+ ## @author Sergey Lyskov
14
+
15
+ import os, os.path, shutil, re, string
16
+ import json
17
+
18
+ import random
19
+
20
+ import imp
21
+ imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/__init__.py') # A bit of Python magic here, what we trying to say is this: from __init__ import *, but init is calculated from file location
22
+
23
+ _api_version_ = '1.0'
24
+
25
+
26
+ def run_state_test(repository_root, working_dir, platform, config):
27
+ revision_id = config['revision']
28
+ states = (_S_passed_, _S_failed_, _S_build_failed_, _S_script_failed_)
29
+ state = states[revision_id % len(states)]
30
+
31
+ return {_StateKey_ : state, _ResultsKey_ : {}, _LogKey_ : f'run_state_test: setting test state to {state!r}...' }
32
+
33
+
34
+ sub_test_description_template = '''\
35
+ # subtests_test test suite
36
+ These sub-test description is generated for 3/4 of sub-tests
37
+
38
+ -----
39
+ ### {name}
40
+ The warm time, had already disappeared like dust. Broken rain, fragment of light shadow, bring more pain to my heart...
41
+ -----
42
+ '''
43
+
44
+ def run_subtests_test(repository_root, working_dir, platform, config):
45
+ tests = {}
46
+ for i in range(16):
47
+ name = f's-{i:02}'
48
+ log = ('x'*63 + '\n') * 16 * 256 * i
49
+ s = i % 3
50
+ if s == 0: state = _S_passed_
51
+ elif s == 1: state = _S_failed_
52
+ else: state = _S_script_failed_
53
+
54
+ if i % 4:
55
+ os.mkdir( f'{working_dir}/{name}' )
56
+ with open(f'{working_dir}/{name}/description.md', 'w') as f: f.write( sub_test_description_template.format(**vars()) )
57
+
58
+ with open( f'{working_dir}/{name}/fantome.txt', 'w') as f: f.write('No one wants to hear the sequel to a fairytale\n')
59
+
60
+ tests[name] = { _StateKey_ : state, _LogKey_ : log, }
61
+
62
+ test_log = ('*'*63 + '\n') * 16 * 1024 * 16
63
+ return {_StateKey_ : _S_failed_, _ResultsKey_ : {_TestsKey_: tests}, _LogKey_ : test_log }
64
+
65
+
66
+ def run_regression_test(repository_root, working_dir, platform, config):
67
+ const = 'const'
68
+ volatile = 'volatile'
69
+ new = ''.join( random.sample( string.ascii_letters + string.digits, 8) )
70
+ oversized = 'oversized'
71
+
72
+ sub_tests = [const, volatile, new]
73
+
74
+ const_dir = working_dir + '/' + const
75
+ os.mkdir(const_dir)
76
+ with open(const_dir + '/const_data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(32) ) ) )
77
+
78
+ volatile_dir = working_dir + '/' + volatile
79
+ os.mkdir(volatile_dir)
80
+ with open(volatile_dir + '/const_data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(32, 64) ) ) )
81
+ with open(volatile_dir + '/volatile_data', 'w') as f: f.write( '\n'.join( ( ''.join(random.sample( string.ascii_letters + string.digits, 8) ) for i in range(32) ) ) )
82
+
83
+ new_dir = working_dir + '/' + new
84
+ os.mkdir(new_dir)
85
+ with open(new_dir + '/data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(64)) ) )
86
+
87
+
88
+ new_dir = working_dir + '/' + oversized
89
+ os.mkdir(new_dir)
90
+ with open(new_dir + '/large', 'w') as f: f.write( ('x'*63 + '\n')*16*1024*256 +'extra')
91
+
92
+ return {_StateKey_ : _S_queued_for_comparison_, _ResultsKey_ : {}, _LogKey_ : f'sub-tests: {sub_tests!r}' }
93
+
94
+
95
+
96
+ def run_release_test(repository_root, working_dir, platform, config):
97
+ release_root = config['mounts'].get('release_root')
98
+
99
+ branch = config['branch']
100
+ revision = config['revision']
101
+
102
+ assert release_root, "config['release_root'] must be set!"
103
+
104
+ release_path = f'{release_root}/dummy'
105
+
106
+ if not os.path.isdir(release_path): os.makedirs(release_path)
107
+
108
+ with open(f'{release_path}/{branch}-{revision}.txt', 'w') as f: f.write('dummy release file\n')
109
+
110
+ return {_StateKey_ : _S_passed_, _ResultsKey_ : {}, _LogKey_ : f'Config release root set to: {release_root}'}
111
+
112
+
113
+
114
+ def run_python_test(repository_root, working_dir, platform, config):
115
+
116
+ import zlib, ssl
117
+
118
+ python_environment = local_python_install(platform, config)
119
+
120
+ if platform['python'][0] == '2': pass
121
+ else:
122
+
123
+ if platform['os'] == 'mac' and int( platform['python'].split('.')[1] ) > 6 :
124
+ # SSL certificate test
125
+ import urllib.request; urllib.request.urlopen('https://benchmark.graylab.jhu.edu')
126
+
127
+ ves = [
128
+ setup_persistent_python_virtual_environment(python_environment, packages='colr dice xdice pdp11games'),
129
+ setup_python_virtual_environment(working_dir, python_environment, packages='colr dice xdice pdp11games'),
130
+ ]
131
+
132
+ for ve in ves:
133
+ commands = [
134
+ 'import colr, dice, xdice, pdp11games',
135
+ ]
136
+
137
+ if platform['os'] == 'mac' and int( platform['python'].split('.')[1] ) > 6 :
138
+ # SSL certificate test
139
+ commands.append('import urllib.request; urllib.request.urlopen("https://benchmark.graylab.jhu.edu/queue")')
140
+
141
+ for command in commands:
142
+ execute('Testing local Python virtual enviroment...', f"{ve.activate} && {ve.python} -c '{command}'")
143
+ execute('Testing local Python virtual enviroment...', f"{ve.activate} && python -c '{command}'")
144
+
145
+
146
+
147
+ return {_StateKey_ : _S_passed_, _ResultsKey_ : {}, _LogKey_ : f'Done!'}
148
+
149
+
150
+
151
+ def compare(test, results, files_path, previous_results, previous_files_path):
152
+ """
153
+ Compare the results of two tests run (new vs. previous) for regression test
154
+ Take two dict and two paths
155
+ Must return standard dict with results
156
+
157
+ :param test: str
158
+ :param results: dict
159
+ :param files_path: str
160
+ :param previous_results: dict
161
+ :param previous_files_path: str
162
+ :rtype: dict
163
+ """
164
+ ignore_files = []
165
+
166
+ results = dict(tests={}, summary=dict(total=0, failed=0, failed_tests=[])) # , config={}
167
+
168
+ if previous_files_path:
169
+ for test in os.listdir(files_path):
170
+ if os.path.isdir(files_path + '/' + test):
171
+ exclude = ''.join([' --exclude="{}"'.format(f) for f in ignore_files] ) + ' --exclude="*.ignore"'
172
+ res, brief_diff = execute('Comparing {}...'.format(test), 'diff -rq {exclude} {0}/{test} {1}/{test}'.format(previous_files_path, files_path, test=test, exclude=exclude), return_='tuple')
173
+ res, full_diff = execute('Comparing {}...'.format(test), 'diff -r {exclude} {0}/{test} {1}/{test}'.format(previous_files_path, files_path, test=test, exclude=exclude), return_='tuple')
174
+ diff = 'Brief Diff:\n' + brief_diff + ( ('\n\nFull Diff:\n' + full_diff[:1024*1024*1]) if full_diff != brief_diff else '' )
175
+
176
+ state = _S_failed_ if res else _S_passed_
177
+ results['tests'][test] = {_StateKey_: state, _LogKey_: diff if state != _S_passed_ else ''}
178
+
179
+ results['summary']['total'] += 1
180
+ if res: results['summary']['failed'] += 1; results['summary']['failed_tests'].append(test)
181
+
182
+ else: # no previous tests case, returning 'passed' for all sub_tests
183
+ for test in os.listdir(files_path):
184
+ if os.path.isdir(files_path + '/' + test):
185
+ results['tests'][test] = {_StateKey_: _S_passed_, _LogKey_: 'First run, no previous results available. Skipping comparison...\n'}
186
+ results['summary']['total'] += 1
187
+
188
+ for test in os.listdir(files_path):
189
+ if os.path.isdir(files_path + '/' + test):
190
+ if os.path.isfile(files_path+'/'+test+'/.test_did_not_run.log') or os.path.isfile(files_path+'/'+test+'/.test_got_timeout_kill.log'):
191
+ results['tests'][test][_StateKey_] = _S_script_failed_
192
+ results['tests'][test][_LogKey_] += '\nCompare(...): Marking as "Script failed" due to presense of .test_did_not_run.log or .test_got_timeout_kill.log file!\n'
193
+ if test not in results['summary']['failed_tests']:
194
+ results['summary']['failed'] += 1
195
+ results['summary']['failed_tests'].append(test)
196
+
197
+ state = _S_failed_ if results['summary']['failed'] else _S_passed_
198
+
199
+ return {_StateKey_: state, _LogKey_: 'Comparison dummy log...', _ResultsKey_: results}
200
+
201
+
202
+ def run(test, repository_root, working_dir, platform, config, hpc_driver=None, verbose=False, debug=False):
203
+ if test == 'state': return run_state_test (repository_root, working_dir, platform, config)
204
+ elif test == 'regression': return run_regression_test (repository_root, working_dir, platform, config)
205
+ elif test == 'subtests': return run_subtests_test (repository_root, working_dir, platform, config)
206
+ elif test == 'release': return run_release_test (repository_root, working_dir, platform, config)
207
+ elif test == 'python': return run_python_test (repository_root, working_dir, platform, config)
208
+
209
+ else: raise BenchmarkError(f'Dummy test script does not support run with test={test!r}!')
config/inference/base.yaml ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base inference Configuration.
2
+
3
+ inference:
4
+ input_pdb: null
5
+ num_designs: 10
6
+ design_startnum: 0
7
+ ckpt_override_path: null
8
+ symmetry: null
9
+ recenter: True
10
+ radius: 10.0
11
+ model_only_neighbors: False
12
+ output_prefix: samples/design
13
+ write_trajectory: True
14
+ scaffold_guided: False
15
+ model_runner: SelfConditioning
16
+ cautious: True
17
+ align_motif: True
18
+ symmetric_self_cond: True
19
+ final_step: 1
20
+ deterministic: False
21
+ trb_save_ckpt_path: null
22
+ schedule_directory_path: null
23
+ model_directory_path: null
24
+
25
+ contigmap:
26
+ contigs: null
27
+ inpaint_seq: null
28
+ provide_seq: null
29
+ length: null
30
+
31
+ model:
32
+ n_extra_block: 4
33
+ n_main_block: 32
34
+ n_ref_block: 4
35
+ d_msa: 256
36
+ d_msa_full: 64
37
+ d_pair: 128
38
+ d_templ: 64
39
+ n_head_msa: 8
40
+ n_head_pair: 4
41
+ n_head_templ: 4
42
+ d_hidden: 32
43
+ d_hidden_templ: 32
44
+ p_drop: 0.15
45
+ SE3_param_full:
46
+ num_layers: 1
47
+ num_channels: 32
48
+ num_degrees: 2
49
+ n_heads: 4
50
+ div: 4
51
+ l0_in_features: 8
52
+ l0_out_features: 8
53
+ l1_in_features: 3
54
+ l1_out_features: 2
55
+ num_edge_features: 32
56
+ SE3_param_topk:
57
+ num_layers: 1
58
+ num_channels: 32
59
+ num_degrees: 2
60
+ n_heads: 4
61
+ div: 4
62
+ l0_in_features: 64
63
+ l0_out_features: 64
64
+ l1_in_features: 3
65
+ l1_out_features: 2
66
+ num_edge_features: 64
67
+ freeze_track_motif: False
68
+ use_motif_timestep: False
69
+
70
+ diffuser:
71
+ T: 50
72
+ b_0: 1e-2
73
+ b_T: 7e-2
74
+ schedule_type: linear
75
+ so3_type: igso3
76
+ crd_scale: 0.25
77
+ partial_T: null
78
+ so3_schedule_type: linear
79
+ min_b: 1.5
80
+ max_b: 2.5
81
+ min_sigma: 0.02
82
+ max_sigma: 1.5
83
+
84
+ denoiser:
85
+ noise_scale_ca: 1
86
+ final_noise_scale_ca: 1
87
+ ca_noise_schedule_type: constant
88
+ noise_scale_frame: 1
89
+ final_noise_scale_frame: 1
90
+ frame_noise_schedule_type: constant
91
+
92
+ ppi:
93
+ hotspot_res: null
94
+
95
+ potentials:
96
+ guiding_potentials: null
97
+ guide_scale: 10
98
+ guide_decay: constant
99
+ olig_inter_all : null
100
+ olig_intra_all : null
101
+ olig_custom_contact : null
102
+ substrate: null
103
+
104
+ contig_settings:
105
+ ref_idx: null
106
+ hal_idx: null
107
+ idx_rf: null
108
+ inpaint_seq_tensor: null
109
+
110
+ preprocess:
111
+ sidechain_input: False
112
+ motif_sidechain_input: True
113
+ d_t1d: 22
114
+ d_t2d: 44
115
+ prob_self_cond: 0.0
116
+ str_self_cond: False
117
+ predict_previous: False
118
+
119
+ logging:
120
+ inputs: False
121
+
122
+ scaffoldguided:
123
+ scaffoldguided: False
124
+ target_pdb: False
125
+ target_path: null
126
+ scaffold_list: null
127
+ scaffold_dir: null
128
+ sampled_insertion: 0
129
+ sampled_N: 0
130
+ sampled_C: 0
131
+ ss_mask: 0
132
+ systematic: False
133
+ target_ss: null
134
+ target_adj: null
135
+ mask_loops: True
136
+ contig_crop: null
config/inference/symmetry.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Config for sampling symmetric assemblies.
2
+
3
+ defaults:
4
+ - base
5
+
6
+ inference:
7
+ # Symmetry to sample
8
+ # Available symmetries:
9
+ # - Cyclic symmetry (C_n) # call as c5
10
+ # - Dihedral symmetry (D_n) # call as d5
11
+ # - Tetrahedral symmetry # call as tetrahedral
12
+ # - Octahedral symmetry # call as octahedral
13
+ # - Icosahedral symmetry # call as icosahedral
14
+ symmetry: c2
15
+
16
+ # Set to true for computational efficiency
17
+ # to avoid memory overhead of modeling all subunits.
18
+ model_only_neighbors: False
19
+
20
+ # Output directory of samples.
21
+ output_prefix: samples/c2
22
+
23
+ contigmap:
24
+ # Specify a single integer value to sample unconditionally.
25
+ # Must be evenly divisible by the number of chains in the symmetry.
26
+ contigs: ['100']
docker/Dockerfile ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Usage:
2
+ # git clone https://github.com/RosettaCommons/RFdiffusion.git
3
+ # cd RFdiffusion
4
+ # docker build -f docker/Dockerfile -t rfdiffusion .
5
+ # mkdir $HOME/inputs $HOME/outputs $HOME/models
6
+ # bash scripts/download_models.sh $HOME/models
7
+ # wget -P $HOME/inputs https://files.rcsb.org/view/5TPN.pdb
8
+
9
+ # docker run -it --rm --gpus all \
10
+ # -v $HOME/models:$HOME/models \
11
+ # -v $HOME/inputs:$HOME/inputs \
12
+ # -v $HOME/outputs:$HOME/outputs \
13
+ # rfdiffusion \
14
+ # inference.output_prefix=$HOME/outputs/motifscaffolding \
15
+ # inference.model_directory_path=$HOME/models \
16
+ # inference.input_pdb=$HOME/inputs/5TPN.pdb \
17
+ # inference.num_designs=3 \
18
+ # 'contigmap.contigs=[10-40/A163-181/10-40]'
19
+
20
+ FROM nvcr.io/nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04
21
+
22
+ COPY . /app/RFdiffusion/
23
+
24
+ RUN apt-get -q update \
25
+ && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
26
+ git \
27
+ python3.9 \
28
+ python3-pip \
29
+ && python3.9 -m pip install -q -U --no-cache-dir pip \
30
+ && rm -rf /var/lib/apt/lists/* \
31
+ && apt-get autoremove -y \
32
+ && apt-get clean \
33
+ && pip install -q --no-cache-dir \
34
+ dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html \
35
+ torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 \
36
+ e3nn==0.3.3 \
37
+ wandb==0.12.0 \
38
+ pynvml==11.0.0 \
39
+ git+https://github.com/NVIDIA/dllogger#egg=dllogger \
40
+ decorator==5.1.0 \
41
+ hydra-core==1.3.2 \
42
+ pyrsistent==0.19.3 \
43
+ /app/RFdiffusion/env/SE3Transformer \
44
+ && pip install --no-cache-dir /app/RFdiffusion --no-deps
45
+
46
+ WORKDIR /app/RFdiffusion
47
+
48
+ ENV DGLBACKEND="pytorch"
49
+
50
+ ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
env/SE3Transformer/.dockerignore ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .Trash-0
2
+ .git
3
+ data/
4
+ .DS_Store
5
+ *wandb/
6
+ *.pt
7
+ *.swp
8
+
9
+ # added by FAFU
10
+ .idea/
11
+ cache/
12
+ downloaded/
13
+ *.lprof
14
+
15
+ # Byte-compiled / optimized / DLL files
16
+ __pycache__/
17
+ *.py[cod]
18
+ *$py.class
19
+
20
+ # C extensions
21
+ *.so
22
+
23
+ # Distribution / packaging
24
+ .Python
25
+ build/
26
+ develop-eggs/
27
+ dist/
28
+ downloads/
29
+ eggs/
30
+ .eggs/
31
+ lib/
32
+ lib64/
33
+ parts/
34
+ sdist/
35
+ var/
36
+ wheels/
37
+ *.egg-info/
38
+ .installed.cfg
39
+ *.egg
40
+ MANIFEST
41
+
42
+ # PyInstaller
43
+ # Usually these files are written by a python script from a template
44
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
45
+ *.manifest
46
+ *.spec
47
+
48
+ # Installer logs
49
+ pip-log.txt
50
+ pip-delete-this-directory.txt
51
+
52
+ # Unit test / coverage reports
53
+ htmlcov/
54
+ .tox/
55
+ .coverage
56
+ .coverage.*
57
+ .cache
58
+ nosetests.xml
59
+ coverage.xml
60
+ *.cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+
64
+ # Translations
65
+ *.mo
66
+ *.pot
67
+
68
+ # Django stuff:
69
+ *.log
70
+ local_settings.py
71
+ db.sqlite3
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # pyenv
90
+ .python-version
91
+
92
+ # celery beat schedule file
93
+ celerybeat-schedule
94
+
95
+ # SageMath parsed files
96
+ *.sage.py
97
+
98
+ # Environments
99
+ .env
100
+ .venv
101
+ env/
102
+ venv/
103
+ ENV/
104
+ env.bak/
105
+ venv.bak/
106
+
107
+ # Spyder project settings
108
+ .spyderproject
109
+ .spyproject
110
+
111
+ # Rope project settings
112
+ .ropeproject
113
+
114
+ # mkdocs documentation
115
+ /site
116
+
117
+ # mypy
118
+ .mypy_cache/
119
+
120
+ **/benchmark
121
+ **/results
122
+ *.pkl
123
+ *.log
env/SE3Transformer/.gitignore ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ .DS_Store
3
+ *wandb/
4
+ *.pt
5
+ *.swp
6
+
7
+ # added by FAFU
8
+ .idea/
9
+ cache/
10
+ downloaded/
11
+ *.lprof
12
+
13
+ # Byte-compiled / optimized / DLL files
14
+ __pycache__/
15
+ *.py[cod]
16
+ *$py.class
17
+
18
+ # C extensions
19
+ *.so
20
+
21
+ # Distribution / packaging
22
+ .Python
23
+ build/
24
+ develop-eggs/
25
+ dist/
26
+ downloads/
27
+ eggs/
28
+ .eggs/
29
+ lib/
30
+ lib64/
31
+ parts/
32
+ sdist/
33
+ var/
34
+ wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .coverage
54
+ .coverage.*
55
+ .cache
56
+ nosetests.xml
57
+ coverage.xml
58
+ *.cover
59
+ .hypothesis/
60
+ .pytest_cache/
61
+
62
+ # Translations
63
+ *.mo
64
+ *.pot
65
+
66
+ # Django stuff:
67
+ *.log
68
+ local_settings.py
69
+ db.sqlite3
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # celery beat schedule file
91
+ celerybeat-schedule
92
+
93
+ # SageMath parsed files
94
+ *.sage.py
95
+
96
+ # Environments
97
+ .env
98
+ .venv
99
+ env/
100
+ venv/
101
+ ENV/
102
+ env.bak/
103
+ venv.bak/
104
+
105
+ # Spyder project settings
106
+ .spyderproject
107
+ .spyproject
108
+
109
+ # Rope project settings
110
+ .ropeproject
111
+
112
+ # mkdocs documentation
113
+ /site
114
+
115
+ # mypy
116
+ .mypy_cache/
117
+
118
+ **/benchmark
119
+ **/results
120
+ *.pkl
121
+ *.log
env/SE3Transformer/Dockerfile ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ # run docker daemon with --default-runtime=nvidia for GPU detection during build
25
+ # multistage build for DGL with CUDA and FP16
26
+
27
+ ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3
28
+
29
+ FROM ${FROM_IMAGE_NAME} AS dgl_builder
30
+
31
+ ENV DEBIAN_FRONTEND=noninteractive
32
+ RUN apt-get update \
33
+ && apt-get install -y git build-essential python3-dev make cmake \
34
+ && rm -rf /var/lib/apt/lists/*
35
+ WORKDIR /dgl
36
+ RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
37
+ RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake
38
+ WORKDIR build
39
+ RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
40
+ RUN make -j8
41
+
42
+
43
+ FROM ${FROM_IMAGE_NAME}
44
+
45
+ RUN rm -rf /workspace/*
46
+ WORKDIR /workspace/se3-transformer
47
+
48
+ # copy built DGL and install it
49
+ COPY --from=dgl_builder /dgl ./dgl
50
+ RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl
51
+
52
+ ADD requirements.txt .
53
+ RUN pip install --no-cache-dir --upgrade --pre pip
54
+ RUN pip install --no-cache-dir -r requirements.txt
55
+ ADD . .
56
+
57
+ ENV DGLBACKEND=pytorch
58
+ ENV OMP_NUM_THREADS=1
env/SE3Transformer/LICENSE ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Copyright 2021 NVIDIA CORPORATION & AFFILIATES
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4
+
5
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
+
7
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
env/SE3Transformer/NOTICE ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ SE(3)-Transformer PyTorch
2
+
3
+ This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public
4
+ licensed under the MIT License.
5
+
6
+ This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch
7
+ licensed under the MIT License.
env/SE3Transformer/README.md ADDED
@@ -0,0 +1,580 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SE(3)-Transformers For PyTorch
2
+
3
+ This repository provides a script and recipe to train the SE(3)-Transformer model to achieve state-of-the-art accuracy. The content of this repository is tested and maintained by NVIDIA.
4
+
5
+ ## Table Of Contents
6
+ - [Model overview](#model-overview)
7
+ * [Model architecture](#model-architecture)
8
+ * [Default configuration](#default-configuration)
9
+ * [Feature support matrix](#feature-support-matrix)
10
+ * [Features](#features)
11
+ * [Mixed precision training](#mixed-precision-training)
12
+ * [Enabling mixed precision](#enabling-mixed-precision)
13
+ * [Enabling TF32](#enabling-tf32)
14
+ * [Glossary](#glossary)
15
+ - [Setup](#setup)
16
+ * [Requirements](#requirements)
17
+ - [Quick Start Guide](#quick-start-guide)
18
+ - [Advanced](#advanced)
19
+ * [Scripts and sample code](#scripts-and-sample-code)
20
+ * [Parameters](#parameters)
21
+ * [Command-line options](#command-line-options)
22
+ * [Getting the data](#getting-the-data)
23
+ * [Dataset guidelines](#dataset-guidelines)
24
+ * [Multi-dataset](#multi-dataset)
25
+ * [Training process](#training-process)
26
+ * [Inference process](#inference-process)
27
+ - [Performance](#performance)
28
+ * [Benchmarking](#benchmarking)
29
+ * [Training performance benchmark](#training-performance-benchmark)
30
+ * [Inference performance benchmark](#inference-performance-benchmark)
31
+ * [Results](#results)
32
+ * [Training accuracy results](#training-accuracy-results)
33
+ * [Training accuracy: NVIDIA DGX A100 (8x A100 80GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-80gb)
34
+ * [Training accuracy: NVIDIA DGX-1 (8x V100 16GB)](#training-accuracy-nvidia-dgx-1-8x-v100-16gb)
35
+ * [Training stability test](#training-stability-test)
36
+ * [Training performance results](#training-performance-results)
37
+ * [Training performance: NVIDIA DGX A100 (8x A100 80GB)](#training-performance-nvidia-dgx-a100-8x-a100-80gb)
38
+ * [Training performance: NVIDIA DGX-1 (8x V100 16GB)](#training-performance-nvidia-dgx-1-8x-v100-16gb)
39
+ * [Inference performance results](#inference-performance-results)
40
+ * [Inference performance: NVIDIA DGX A100 (1x A100 80GB)](#inference-performance-nvidia-dgx-a100-1x-a100-80gb)
41
+ * [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
42
+ - [Release notes](#release-notes)
43
+ * [Changelog](#changelog)
44
+ * [Known issues](#known-issues)
45
+
46
+
47
+
48
+ ## Model overview
49
+
50
+
51
+ The **SE(3)-Transformer** is a Graph Neural Network using a variant of [self-attention](https://arxiv.org/abs/1706.03762v5) for 3D points and graphs processing.
52
+ This model is [equivariant](https://en.wikipedia.org/wiki/Equivariant_map) under [continuous 3D roto-translations](https://en.wikipedia.org/wiki/Euclidean_group), meaning that when the inputs (graphs or sets of points) rotate in 3D space (or more generally experience a [proper rigid transformation](https://en.wikipedia.org/wiki/Rigid_transformation)), the model outputs either stay invariant or transform with the input.
53
+ A mathematical guarantee of equivariance is important to ensure stable and predictable performance in the presence of nuisance transformations of the data input and when the problem has some inherent symmetries we want to exploit.
54
+
55
+
56
+ The model is based on the following publications:
57
+ - [SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks](https://arxiv.org/abs/2006.10503) (NeurIPS 2020) by Fabian B. Fuchs, Daniel E. Worrall, et al.
58
+ - [Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds](https://arxiv.org/abs/1802.08219) by Nathaniel Thomas, Tess Smidt, et al.
59
+
60
+ A follow-up paper explains how this model can be used iteratively, for example, to predict or refine protein structures:
61
+
62
+ - [Iterative SE(3)-Transformers](https://arxiv.org/abs/2102.13419) by Fabian B. Fuchs, Daniel E. Worrall, et al.
63
+
64
+ Just like [the official implementation](https://github.com/FabianFuchsML/se3-transformer-public), this implementation uses [PyTorch](https://pytorch.org/) and the [Deep Graph Library (DGL)](https://www.dgl.ai/).
65
+
66
+ The main differences between this implementation of SE(3)-Transformers and the official one are the following:
67
+
68
+ - Training and inference support for multiple GPUs
69
+ - Training and inference support for [Mixed Precision](https://arxiv.org/abs/1710.03740)
70
+ - The [QM9 dataset from DGL](https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset) is used and automatically downloaded
71
+ - Significantly increased throughput
72
+ - Significantly reduced memory consumption
73
+ - The use of layer normalization in the fully connected radial profile layers is an option (`--use_layer_norm`), off by default
74
+ - The use of equivariant normalization between attention layers is an option (`--norm`), off by default
75
+ - The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonic) and [Clebsch–Gordan coefficients](https://en.wikipedia.org/wiki/Clebsch%E2%80%93Gordan_coefficients), used to compute bases matrices, are computed with the [e3nn library](https://e3nn.org/)
76
+
77
+
78
+
79
+ This model enables you to predict quantum chemical properties of small organic molecules in the [QM9 dataset](https://www.nature.com/articles/sdata201422).
80
+ In this case, the exploited symmetry is that these properties do not depend on the orientation or position of the molecules in space.
81
+
82
+
83
+ This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, NVIDIA Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results up to 1.5x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
84
+
85
+ ### Model architecture
86
+
87
+ The model consists of stacked layers of equivariant graph self-attention and equivariant normalization.
88
+ Lastly, a Tensor Field Network convolution is applied to obtain invariant features. Graph pooling (mean or max over the nodes) is applied to these features, and the result is fed to a final MLP to get scalar predictions.
89
+
90
+ In this setup, the model is a graph-to-scalar network. The pooling can be removed to obtain a graph-to-graph network, and the final TFN can be modified to output features of any type (invariant scalars, 3D vectors, ...).
91
+
92
+
93
+ ![Model high-level architecture](./images/se3-transformer.png)
94
+
95
+
96
+ ### Default configuration
97
+
98
+
99
+ SE(3)-Transformers introduce a self-attention layer for graphs that is equivariant to 3D roto-translations. It achieves this by leveraging Tensor Field Networks to build attention weights that are invariant and attention values that are equivariant.
100
+ Combining the equivariant values with the invariant weights gives rise to an equivariant output. This output is normalized while preserving equivariance thanks to equivariant normalization layers operating on feature norms.
101
+
102
+
103
+ The following features were implemented in this model:
104
+
105
+ - Support for edge features of any degree (1D, 3D, 5D, ...), whereas the official implementation only supports scalar invariant edge features (degree 0). Edge features with a degree greater than one are
106
+ concatenated to node features of the same degree. This is required in order to reproduce published results on point cloud processing.
107
+ - Data-parallel multi-GPU training (DDP)
108
+ - Mixed precision training (autocast, gradient scaling)
109
+ - Gradient accumulation
110
+ - Model checkpointing
111
+
112
+
113
+ The following performance optimizations were implemented in this model:
114
+
115
+
116
+ **General optimizations**
117
+
118
+ - The option is provided to precompute bases at the beginning of the training instead of computing them at the beginning of each forward pass (`--precompute_bases`)
119
+ - The bases computation is just-in-time (JIT) compiled with `torch.jit.script`
120
+ - The Clebsch-Gordon coefficients are cached in RAM
121
+
122
+
123
+ **Tensor Field Network optimizations**
124
+
125
+ - The last layer of each radial profile network does not add any bias in order to avoid large broadcasting operations
126
+ - The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
127
+ - When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
128
+ - Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
129
+ - A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
130
+
131
+ **Self-attention optimizations**
132
+
133
+ - Attention keys and values are computed by a single partial TFN graph convolution in each attention layer instead of two
134
+ - Graph operations for different output degrees may be fused together if conditions are met
135
+
136
+
137
+ **Normalization optimizations**
138
+
139
+ - The equivariant normalization layer is optimized from multiple layer normalizations to a group normalization on fused norms when certain conditions are met
140
+
141
+
142
+
143
+ Competitive training results and analysis are provided for the following hyperparameters (identical to the ones in the original publication):
144
+ - Number of layers: 7
145
+ - Number of degrees: 4
146
+ - Number of channels: 32
147
+ - Number of attention heads: 8
148
+ - Channels division: 2
149
+ - Use of equivariant normalization: true
150
+ - Use of layer normalization: true
151
+ - Pooling: max
152
+
153
+
154
+ ### Feature support matrix
155
+
156
+ This model supports the following features::
157
+
158
+ | Feature | SE(3)-Transformer
159
+ |-----------------------|--------------------------
160
+ |Automatic mixed precision (AMP) | Yes
161
+ |Distributed data parallel (DDP) | Yes
162
+
163
+ #### Features
164
+
165
+
166
+ **Distributed data parallel (DDP)**
167
+
168
+ [DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implements data parallelism at the module level that can run across multiple GPUs or machines.
169
+
170
+ **Automatic Mixed Precision (AMP)**
171
+
172
+ This implementation uses the native PyTorch AMP implementation of mixed precision training. It allows us to use FP16 training with FP32 master weights by modifying just a few lines of code. A detailed explanation of mixed precision can be found in the next section.
173
+
174
+ ### Mixed precision training
175
+
176
+ Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in NVIDIA Volta, and following with both the NVIDIA Turing and NVIDIA Ampere Architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using [mixed precision training](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) previously required two steps:
177
+ 1. Porting the model to use the FP16 data type where appropriate.
178
+ 2. Adding loss scaling to preserve small gradient values.
179
+
180
+ AMP enables mixed precision training on NVIDIA Volta, NVIDIA Turing, and NVIDIA Ampere GPU architectures automatically. The PyTorch framework code makes all necessary model changes internally.
181
+
182
+ For information about:
183
+ - How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation.
184
+ - Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
185
+ - APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
186
+
187
+ #### Enabling mixed precision
188
+
189
+ Mixed precision is enabled in PyTorch by using the native [Automatic Mixed Precision package](https://pytorch.org/docs/stable/amp.html), which casts variables to half-precision upon retrieval while storing variables in single-precision format. Furthermore, to preserve small gradient magnitudes in backpropagation, a [loss scaling](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#lossscaling) step must be included when applying gradients. In PyTorch, loss scaling can be applied automatically using a `GradScaler`.
190
+ Automatic Mixed Precision makes all the adjustments internally in PyTorch, providing two benefits over manual operations. First, programmers need not modify network model code, reducing development and maintenance effort. Second, using AMP maintains forward and backward compatibility with all the APIs for defining and running PyTorch models.
191
+
192
+ To enable mixed precision, you can simply use the `--amp` flag when running the training or inference scripts.
193
+
194
+ #### Enabling TF32
195
+
196
+ TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math, also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on NVIDIA Volta GPUs.
197
+
198
+ TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models that require a high dynamic range for weights or activations.
199
+
200
+ For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
201
+
202
+ TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
203
+
204
+
205
+
206
+ ### Glossary
207
+
208
+ **Degree (type)**
209
+
210
+ In the model, every feature (input, output and hidden) transforms in an equivariant way in relation to the input graph. When we define a feature, we need to choose, in addition to the number of channels, which transformation rule it obeys.
211
+
212
+ The degree or type of a feature is a positive integer that describes how this feature transforms when the input rotates in 3D.
213
+
214
+ This is related to [irreducible representations](https://en.wikipedia.org/wiki/Irreducible_representation) of different rotation orders.
215
+
216
+ The degree of a feature determines its dimensionality. A type-d feature has a dimensionality of 2d+1.
217
+
218
+ Some common examples include:
219
+ - Degree 0: 1D scalars invariant to rotation
220
+ - Degree 1: 3D vectors that rotate according to 3D rotation matrices
221
+ - Degree 2: 5D vectors that rotate according to 5D [Wigner-D matrices](https://en.wikipedia.org/wiki/Wigner_D-matrix). These can represent symmetric traceless 3x3 matrices.
222
+
223
+ **Fiber**
224
+
225
+ A fiber can be viewed as a representation of a set of features of different types or degrees (positive integers), where each feature type transforms according to its rule.
226
+
227
+ In this repository, a fiber can be seen as a dictionary with degrees as keys and numbers of channels as values.
228
+
229
+ **Multiplicity**
230
+
231
+ The multiplicity of a feature of a given type is the number of channels of this feature.
232
+
233
+ **Tensor Field Network**
234
+
235
+ A [Tensor Field Network](https://arxiv.org/abs/1802.08219) is a kind of equivariant graph convolution that can combine features of different degrees and produce new ones while preserving equivariance thanks to [tensor products](https://en.wikipedia.org/wiki/Tensor_product).
236
+
237
+ **Equivariance**
238
+
239
+ [Equivariance](https://en.wikipedia.org/wiki/Equivariant_map) is a property of a function of model stating that applying a symmetry transformation to the input and then computing the function produces the same result as computing the function and then applying the transformation to the output.
240
+
241
+ In the case of SE(3)-Transformer, the symmetry group is the group of continuous roto-translations (SE(3)).
242
+
243
+ ## Setup
244
+
245
+ The following section lists the requirements that you need to meet in order to start training the SE(3)-Transformer model.
246
+
247
+ ### Requirements
248
+
249
+ This repository contains a Dockerfile which extends the PyTorch 21.07 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
250
+ - [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
251
+ - PyTorch 21.07+ NGC container
252
+ - Supported GPUs:
253
+ - [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
254
+ - [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
255
+ - [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
256
+
257
+ For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
258
+ - [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
259
+ - [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
260
+ - [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
261
+
262
+ For those unable to use the PyTorch NGC container to set up the required environment or create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
263
+
264
+ ## Quick Start Guide
265
+
266
+ To train your model using mixed or TF32 precision with Tensor Cores or FP32, perform the following steps using the default parameters of the SE(3)-Transformer model on the QM9 dataset. For the specifics concerning training and inference, refer to the [Advanced](#advanced) section.
267
+
268
+ 1. Clone the repository.
269
+ ```
270
+ git clone https://github.com/NVIDIA/DeepLearningExamples
271
+ cd DeepLearningExamples/PyTorch/DrugDiscovery/SE3Transformer
272
+ ```
273
+
274
+ 2. Build the `se3-transformer` PyTorch NGC container.
275
+ ```
276
+ docker build -t se3-transformer .
277
+ ```
278
+
279
+ 3. Start an interactive session in the NGC container to run training/inference.
280
+ ```
281
+ mkdir -p results
282
+ docker run -it --runtime=nvidia --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 --rm -v ${PWD}/results:/results se3-transformer:latest
283
+ ```
284
+
285
+ 4. Start training.
286
+ ```
287
+ bash scripts/train.sh
288
+ ```
289
+
290
+ 5. Start inference/predictions.
291
+ ```
292
+ bash scripts/predict.sh
293
+ ```
294
+
295
+
296
+ Now that you have your model trained and evaluated, you can choose to compare your training results with our [Training accuracy results](#training-accuracy-results). You can also choose to benchmark your performance to [Training performance benchmark](#training-performance-results) or [Inference performance benchmark](#inference-performance-results). Following the steps in these sections will ensure that you achieve the same accuracy and performance results as stated in the [Results](#results) section.
297
+
298
+ ## Advanced
299
+
300
+ The following sections provide greater details of the dataset, running training and inference, and the training results.
301
+
302
+ ### Scripts and sample code
303
+
304
+ In the root directory, the most important files are:
305
+ - `Dockerfile`: container with the basic set of dependencies to run SE(3)-Transformers
306
+ - `requirements.txt`: set of extra requirements to run SE(3)-Transformers
307
+ - `se3_transformer/data_loading/qm9.py`: QM9 data loading and preprocessing, as well as bases precomputation
308
+ - `se3_transformer/model/layers/`: directory containing model architecture layers
309
+ - `se3_transformer/model/transformer.py`: main Transformer module
310
+ - `se3_transformer/model/basis.py`: logic for computing bases matrices
311
+ - `se3_transformer/runtime/training.py`: training script, to be run as a python module
312
+ - `se3_transformer/runtime/inference.py`: inference script, to be run as a python module
313
+ - `se3_transformer/runtime/metrics.py`: MAE metric with support for multi-GPU synchronization
314
+ - `se3_transformer/runtime/loggers.py`: [DLLogger](https://github.com/NVIDIA/dllogger) and [W&B](wandb.ai/) loggers
315
+
316
+
317
+ ### Parameters
318
+
319
+ The complete list of the available parameters for the `training.py` script contains:
320
+
321
+ **General**
322
+
323
+ - `--epochs`: Number of training epochs (default: `100` for single-GPU)
324
+ - `--batch_size`: Batch size (default: `240`)
325
+ - `--seed`: Set a seed globally (default: `None`)
326
+ - `--num_workers`: Number of dataloading workers (default: `8`)
327
+ - `--amp`: Use Automatic Mixed Precision (default `false`)
328
+ - `--gradient_clip`: Clipping of the gradient norms (default: `None`)
329
+ - `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
330
+ - `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
331
+ - `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
332
+ - `--silent`: Minimize stdout output (default: `false`)
333
+
334
+ **Paths**
335
+
336
+ - `--data_dir`: Directory where the data is located or should be downloaded (default: `./data`)
337
+ - `--log_dir`: Directory where the results logs should be saved (default: `/results`)
338
+ - `--save_ckpt_path`: File where the checkpoint should be saved (default: `None`)
339
+ - `--load_ckpt_path`: File of the checkpoint to be loaded (default: `None`)
340
+
341
+ **Optimizer**
342
+
343
+ - `--optimizer`: Optimizer to use (default: `adam`)
344
+ - `--learning_rate`: Learning rate to use (default: `0.002` for single-GPU)
345
+ - `--momentum`: Momentum to use (default: `0.9`)
346
+ - `--weight_decay`: Weight decay to use (default: `0.1`)
347
+
348
+ **QM9 dataset**
349
+
350
+ - `--task`: Regression task to train on (default: `homo`)
351
+ - `--precompute_bases`: Precompute bases at the beginning of the script during dataset initialization, instead of computing them at the beginning of each forward pass (default: `false`)
352
+
353
+ **Model architecture**
354
+
355
+ - `--num_layers`: Number of stacked Transformer layers (default: `7`)
356
+ - `--num_heads`: Number of heads in self-attention (default: `8`)
357
+ - `--channels_div`: Channels division before feeding to attention layer (default: `2`)
358
+ - `--pooling`: Type of graph pooling (default: `max`)
359
+ - `--norm`: Apply a normalization layer after each attention block (default: `false`)
360
+ - `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
361
+ - `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
362
+ - `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
363
+ - `--num_channels`: Number of channels for the hidden features (default: `32`)
364
+
365
+
366
+ ### Command-line options
367
+
368
+ To show the full list of available options and their descriptions, use the `-h` or `--help` command-line option, for example: `python -m se3_transformer.runtime.training --help`.
369
+
370
+
371
+ ### Dataset guidelines
372
+
373
+ #### Demo dataset
374
+
375
+ The SE(3)-Transformer was trained on the QM9 dataset.
376
+
377
+ The QM9 dataset is hosted on DGL servers and downloaded (38MB) automatically when needed. By default, it is stored in the `./data` directory, but this location can be changed with the `--data_dir` argument.
378
+
379
+ The dataset is saved as a `qm9_edge.npz` file and converted to DGL graphs at runtime.
380
+
381
+ As input features, we use:
382
+ - Node features (6D):
383
+ - One-hot-encoded atom type (5D) (atom types: H, C, N, O, F)
384
+ - Number of protons of each atom (1D)
385
+ - Edge features: one-hot-encoded bond type (4D) (bond types: single, double, triple, aromatic)
386
+ - The relative positions between adjacent nodes (atoms)
387
+
388
+ #### Custom datasets
389
+
390
+ To use this network on a new dataset, you can extend the `DataModule` class present in `se3_transformer/data_loading/data_module.py`.
391
+
392
+ Your custom collate function should return a tuple with:
393
+
394
+ - A (batched) DGLGraph object
395
+ - A dictionary of node features ({‘{degree}’: tensor})
396
+ - A dictionary of edge features ({‘{degree}’: tensor})
397
+ - (Optional) Precomputed bases as a dictionary
398
+ - Labels as a tensor
399
+
400
+ You can then modify the `training.py` and `inference.py` scripts to use your new data module.
401
+
402
+ ### Training process
403
+
404
+ The training script is `se3_transformer/runtime/training.py`, to be run as a module: `python -m se3_transformer.runtime.training`.
405
+
406
+ **Logs**
407
+
408
+ By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
409
+
410
+ You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
411
+
412
+ **Checkpoints**
413
+
414
+ The argument `--save_ckpt_path` can be set to the path of the file where the checkpoints should be saved.
415
+ `--ckpt_interval` can also be set to the interval (in the number of epochs) between checkpoints.
416
+
417
+ **Evaluation**
418
+
419
+ The evaluation metric is the Mean Absolute Error (MAE).
420
+
421
+ `--eval_interval` can be set to the interval (in the number of epochs) between evaluation rounds. By default, an evaluation round is performed after each epoch.
422
+
423
+ **Automatic Mixed Precision**
424
+
425
+ To enable Mixed Precision training, add the `--amp` flag.
426
+
427
+ **Multi-GPU and multi-node**
428
+
429
+ The training script supports the PyTorch elastic launcher to run on multiple GPUs or nodes. Refer to the [official documentation](https://pytorch.org/docs/1.9.0/elastic/run.html).
430
+
431
+ For example, to train on all available GPUs with AMP:
432
+
433
+ ```
434
+ python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --module se3_transformer.runtime.training --amp
435
+ ```
436
+
437
+
438
+ ### Inference process
439
+
440
+ Inference can be run by using the `se3_transformer.runtime.inference` python module.
441
+
442
+ The inference script is `se3_transformer/runtime/inference.py`, to be run as a module: `python -m se3_transformer.runtime.inference`. It requires a pre-trained model checkpoint (to be passed as `--load_ckpt_path`).
443
+
444
+
445
+ ## Performance
446
+
447
+ The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
448
+
449
+ ### Benchmarking
450
+
451
+ The following section shows how to run benchmarks measuring the model performance in training and inference modes.
452
+
453
+ #### Training performance benchmark
454
+
455
+ To benchmark the training performance on a specific batch size, run `bash scripts/benchmarck_train.sh {BATCH_SIZE}` for single GPU, and `bash scripts/benchmarck_train_multi_gpu.sh {BATCH_SIZE}` for multi-GPU.
456
+
457
+ #### Inference performance benchmark
458
+
459
+ To benchmark the inference performance on a specific batch size, run `bash scripts/benchmarck_inference.sh {BATCH_SIZE}`.
460
+
461
+ ### Results
462
+
463
+
464
+ The following sections provide details on how we achieved our performance and accuracy in training and inference.
465
+
466
+ #### Training accuracy results
467
+
468
+ ##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)
469
+
470
+ Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 (8x A100 80GB) GPUs.
471
+
472
+ | GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
473
+ |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
474
+ | 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
475
+ | 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
476
+
477
+
478
+ ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
479
+
480
+ Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
481
+
482
+ | GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
483
+ |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
484
+ | 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
485
+ | 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
486
+
487
+
488
+ #### Training performance results
489
+
490
+ ##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
491
+
492
+ Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 8x A100 80GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
493
+
494
+ | GPUs | Batch size / GPU | Throughput - TF32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (mixed precision - TF32) | Weak scaling - TF32 | Weak scaling - mixed precision |
495
+ |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
496
+ | 1 | 240 | 2.21 | 2.92 | 1.32x | | |
497
+ | 1 | 120 | 1.81 | 2.04 | 1.13x | | |
498
+ | 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
499
+ | 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
500
+
501
+
502
+ To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
503
+
504
+
505
+ ##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
506
+
507
+ Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 8x V100 16GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
508
+
509
+ | GPUs | Batch size / GPU | Throughput - FP32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision |
510
+ |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
511
+ | 1 | 240 | 1.25 | 1.88 | 1.50x | | |
512
+ | 1 | 120 | 1.03 | 1.41 | 1.37x | | |
513
+ | 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
514
+ | 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
515
+
516
+
517
+ To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
518
+
519
+
520
+ #### Inference performance results
521
+
522
+
523
+ ##### Inference performance: NVIDIA DGX A100 (1x A100 80GB)
524
+
525
+ Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 1x A100 80GB GPU.
526
+
527
+ FP16
528
+
529
+ | Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
530
+ |:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
531
+ | 1600 | 11.60 | 140.94 | 138.29 | 140.12 | 386.40 |
532
+ | 800 | 10.74 | 75.69 | 75.74 | 76.50 | 79.77 |
533
+ | 400 | 8.86 | 45.57 | 46.11 | 46.60 | 49.97 |
534
+
535
+ TF32
536
+
537
+ | Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
538
+ |:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
539
+ | 1600 | 8.58 | 189.20 | 186.39 | 187.71 | 420.28 |
540
+ | 800 | 8.28 | 97.56 | 97.20 | 97.73 | 101.13 |
541
+ | 400 | 7.55 | 53.38 | 53.72 | 54.48 | 56.62 |
542
+
543
+ To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
544
+
545
+
546
+
547
+ ##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
548
+
549
+ Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU.
550
+
551
+ FP16
552
+
553
+ | Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
554
+ |:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
555
+ | 1600 | 6.42 | 254.54 | 247.97 | 249.29 | 721.15 |
556
+ | 800 | 6.13 | 132.07 | 131.90 | 132.70 | 140.15 |
557
+ | 400 | 5.37 | 75.12 | 76.01 | 76.66 | 79.90 |
558
+
559
+ FP32
560
+
561
+ | Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
562
+ |:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
563
+ | 1600 | 3.39 | 475.86 | 473.82 | 475.64 | 891.18 |
564
+ | 800 | 3.36 | 239.17 | 240.64 | 241.65 | 243.70 |
565
+ | 400 | 3.17 | 126.67 | 128.19 | 128.82 | 130.54 |
566
+
567
+
568
+ To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
569
+
570
+
571
+ ## Release notes
572
+
573
+ ### Changelog
574
+
575
+ August 2021
576
+ - Initial release
577
+
578
+ ### Known issues
579
+
580
+ If you encounter `OSError: [Errno 12] Cannot allocate memory` during the Dataloader iterator creation (more precisely during the `fork()`, this is most likely due to the use of the `--precompute_bases` flag. If you cannot add more RAM or Swap to your machine, it is recommended to turn off bases precomputation by removing the `--precompute_bases` flag or using `--precompute_bases false`.
env/SE3Transformer/build/lib/se3_transformer/__init__.py ADDED
File without changes
env/SE3Transformer/build/lib/se3_transformer/data_loading/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .qm9 import QM9DataModule
env/SE3Transformer/build/lib/se3_transformer/data_loading/data_module.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import torch.distributed as dist
25
+ from abc import ABC
26
+ from torch.utils.data import DataLoader, DistributedSampler, Dataset
27
+
28
+ from se3_transformer.runtime.utils import get_local_rank
29
+
30
+
31
+ def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
32
+ # Classic or distributed dataloader depending on the context
33
+ sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
34
+ return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)
35
+
36
+
37
+ class DataModule(ABC):
38
+ """ Abstract DataModule. Children must define self.ds_{train | val | test}. """
39
+
40
+ def __init__(self, **dataloader_kwargs):
41
+ super().__init__()
42
+ if get_local_rank() == 0:
43
+ self.prepare_data()
44
+
45
+ # Wait until rank zero has prepared the data (download, preprocessing, ...)
46
+ if dist.is_initialized():
47
+ dist.barrier(device_ids=[get_local_rank()])
48
+
49
+ self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
50
+ self.ds_train, self.ds_val, self.ds_test = None, None, None
51
+
52
+ def prepare_data(self):
53
+ """ Method called only once per node. Put here any downloading or preprocessing """
54
+ pass
55
+
56
+ def train_dataloader(self) -> DataLoader:
57
+ return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)
58
+
59
+ def val_dataloader(self) -> DataLoader:
60
+ return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)
61
+
62
+ def test_dataloader(self) -> DataLoader:
63
+ return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)
env/SE3Transformer/build/lib/se3_transformer/data_loading/qm9.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+ from typing import Tuple
24
+
25
+ import dgl
26
+ import pathlib
27
+ import torch
28
+ from dgl.data import QM9EdgeDataset
29
+ from dgl import DGLGraph
30
+ from torch import Tensor
31
+ from torch.utils.data import random_split, DataLoader, Dataset
32
+ from tqdm import tqdm
33
+
34
+ from se3_transformer.data_loading.data_module import DataModule
35
+ from se3_transformer.model.basis import get_basis
36
+ from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
37
+
38
+
39
+ def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
40
+ x = qm9_graph.ndata['pos']
41
+ src, dst = qm9_graph.edges()
42
+ rel_pos = x[dst] - x[src]
43
+ return rel_pos
44
+
45
+
46
+ def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
47
+ len_full = len(full_dataset)
48
+ len_train = 100_000
49
+ len_test = int(0.1 * len_full)
50
+ len_val = len_full - len_train - len_test
51
+ return len_train, len_val, len_test
52
+
53
+
54
+ class QM9DataModule(DataModule):
55
+ """
56
+ Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset
57
+ Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest.
58
+ This includes all the molecules from QM9 except the ones that are uncharacterized.
59
+ """
60
+
61
+ NODE_FEATURE_DIM = 6
62
+ EDGE_FEATURE_DIM = 4
63
+
64
+ def __init__(self,
65
+ data_dir: pathlib.Path,
66
+ task: str = 'homo',
67
+ batch_size: int = 240,
68
+ num_workers: int = 8,
69
+ num_degrees: int = 4,
70
+ amp: bool = False,
71
+ precompute_bases: bool = False,
72
+ **kwargs):
73
+ self.data_dir = data_dir # This needs to be before __init__ so that prepare_data has access to it
74
+ super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
75
+ self.amp = amp
76
+ self.task = task
77
+ self.batch_size = batch_size
78
+ self.num_degrees = num_degrees
79
+
80
+ qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
81
+ if precompute_bases:
82
+ bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
83
+ full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
84
+ num_workers=num_workers, **qm9_kwargs)
85
+ else:
86
+ full_dataset = QM9EdgeDataset(**qm9_kwargs)
87
+
88
+ self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset),
89
+ generator=torch.Generator().manual_seed(0))
90
+
91
+ train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]]
92
+ self.targets_mean = train_targets.mean()
93
+ self.targets_std = train_targets.std()
94
+
95
+ def prepare_data(self):
96
+ # Download the QM9 preprocessed data
97
+ QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir))
98
+
99
+ def _collate(self, samples):
100
+ graphs, y, *bases = map(list, zip(*samples))
101
+ batched_graph = dgl.batch(graphs)
102
+ edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
103
+ batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
104
+ # get node features
105
+ node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
106
+ targets = (torch.cat(y) - self.targets_mean) / self.targets_std
107
+
108
+ if bases:
109
+ # collate bases
110
+ all_bases = {
111
+ key: torch.cat([b[key] for b in bases[0]], dim=0)
112
+ for key in bases[0][0].keys()
113
+ }
114
+
115
+ return batched_graph, node_feats, edge_feats, all_bases, targets
116
+ else:
117
+ return batched_graph, node_feats, edge_feats, targets
118
+
119
+ @staticmethod
120
+ def add_argparse_args(parent_parser):
121
+ parser = parent_parser.add_argument_group("QM9 dataset")
122
+ parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?',
123
+ choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
124
+ 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'],
125
+ help='Regression task to train on')
126
+ parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
127
+ help='Precompute bases at the beginning of the script during dataset initialization,'
128
+ ' instead of computing them at the beginning of each forward pass.')
129
+ return parent_parser
130
+
131
+ def __repr__(self):
132
+ return f'QM9({self.task})'
133
+
134
+
135
+ class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
136
+ """ Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
137
+
138
+ def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
139
+ """
140
+ :param bases_kwargs: Arguments to feed the bases computation function
141
+ :param batch_size: Batch size to use when iterating over the dataset for computing bases
142
+ """
143
+ self.bases_kwargs = bases_kwargs
144
+ self.batch_size = batch_size
145
+ self.bases = None
146
+ self.num_workers = num_workers
147
+ super().__init__(*args, **kwargs)
148
+
149
+ def load(self):
150
+ super().load()
151
+ # Iterate through the dataset and compute bases (pairwise only)
152
+ # Potential improvement: use multi-GPU and gather
153
+ dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
154
+ collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
155
+ bases = []
156
+ for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
157
+ disable=get_local_rank() != 0):
158
+ rel_pos = _get_relative_pos(graph)
159
+ # Compute the bases with the GPU but convert the result to CPU to store in RAM
160
+ bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
161
+ self.bases = bases # Assign at the end so that __getitem__ isn't confused
162
+
163
+ def __getitem__(self, idx: int):
164
+ graph, label = super().__getitem__(idx)
165
+
166
+ if self.bases:
167
+ bases_idx = idx // self.batch_size
168
+ bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size]
169
+ bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size]
170
+ return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in
171
+ self.bases[bases_idx].items()}
172
+ else:
173
+ return graph, label
env/SE3Transformer/build/lib/se3_transformer/model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .transformer import SE3Transformer, SE3TransformerPooled
2
+ from .fiber import Fiber
env/SE3Transformer/build/lib/se3_transformer/model/basis.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+
25
+ from functools import lru_cache
26
+ from typing import Dict, List
27
+
28
+ import e3nn.o3 as o3
29
+ import torch
30
+ import torch.nn.functional as F
31
+ from torch import Tensor
32
+ from torch.cuda.nvtx import range as nvtx_range
33
+
34
+ from se3_transformer.runtime.utils import degree_to_dim
35
+
36
+
37
+ @lru_cache(maxsize=None)
38
+ def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
39
+ """ Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """
40
+ return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)
41
+
42
+
43
+ @lru_cache(maxsize=None)
44
+ def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
45
+ all_cb = []
46
+ for d_in in range(max_degree + 1):
47
+ for d_out in range(max_degree + 1):
48
+ K_Js = []
49
+ for J in range(abs(d_in - d_out), d_in + d_out + 1):
50
+ K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
51
+ all_cb.append(K_Js)
52
+ return all_cb
53
+
54
+
55
+ def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
56
+ all_degrees = list(range(2 * max_degree + 1))
57
+ with nvtx_range('spherical harmonics'):
58
+ sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
59
+ return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
60
+
61
+
62
+ @torch.jit.script
63
+ def get_basis_script(max_degree: int,
64
+ use_pad_trick: bool,
65
+ spherical_harmonics: List[Tensor],
66
+ clebsch_gordon: List[List[Tensor]],
67
+ amp: bool) -> Dict[str, Tensor]:
68
+ """
69
+ Compute pairwise bases matrices for degrees up to max_degree
70
+ :param max_degree: Maximum input or output degree
71
+ :param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
72
+ :param spherical_harmonics: List of computed spherical harmonics
73
+ :param clebsch_gordon: List of computed CB-coefficients
74
+ :param amp: When true, return bases in FP16 precision
75
+ """
76
+ basis = {}
77
+ idx = 0
78
+ # Double for loop instead of product() because of JIT script
79
+ for d_in in range(max_degree + 1):
80
+ for d_out in range(max_degree + 1):
81
+ key = f'{d_in},{d_out}'
82
+ K_Js = []
83
+ for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
84
+ Q_J = clebsch_gordon[idx][freq_idx]
85
+ K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))
86
+
87
+ basis[key] = torch.stack(K_Js, 2) # Stack on second dim so order is n l f k
88
+ if amp:
89
+ basis[key] = basis[key].half()
90
+ if use_pad_trick:
91
+ basis[key] = F.pad(basis[key], (0, 1)) # Pad the k dimension, that can be sliced later
92
+
93
+ idx += 1
94
+
95
+ return basis
96
+
97
+
98
+ @torch.jit.script
99
+ def update_basis_with_fused(basis: Dict[str, Tensor],
100
+ max_degree: int,
101
+ use_pad_trick: bool,
102
+ fully_fused: bool) -> Dict[str, Tensor]:
103
+ """ Update the basis dict with partially and optionally fully fused bases """
104
+ num_edges = basis['0,0'].shape[0]
105
+ device = basis['0,0'].device
106
+ dtype = basis['0,0'].dtype
107
+ sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])
108
+
109
+ # Fused per output degree
110
+ for d_out in range(max_degree + 1):
111
+ sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
112
+ basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
113
+ device=device, dtype=dtype)
114
+ acc_d, acc_f = 0, 0
115
+ for d_in in range(max_degree + 1):
116
+ basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
117
+ :degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
118
+
119
+ acc_d += degree_to_dim(d_in)
120
+ acc_f += degree_to_dim(min(d_out, d_in))
121
+
122
+ basis[f'out{d_out}_fused'] = basis_fused
123
+
124
+ # Fused per input degree
125
+ for d_in in range(max_degree + 1):
126
+ sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
127
+ basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
128
+ device=device, dtype=dtype)
129
+ acc_d, acc_f = 0, 0
130
+ for d_out in range(max_degree + 1):
131
+ basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
132
+ = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
133
+
134
+ acc_d += degree_to_dim(d_out)
135
+ acc_f += degree_to_dim(min(d_out, d_in))
136
+
137
+ basis[f'in{d_in}_fused'] = basis_fused
138
+
139
+ if fully_fused:
140
+ # Fully fused
141
+ # Double sum this way because of JIT script
142
+ sum_freq = sum([
143
+ sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
144
+ ])
145
+ basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)
146
+
147
+ acc_d, acc_f = 0, 0
148
+ for d_out in range(max_degree + 1):
149
+ b = basis[f'out{d_out}_fused']
150
+ basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
151
+ :degree_to_dim(d_out)]
152
+ acc_f += b.shape[2]
153
+ acc_d += degree_to_dim(d_out)
154
+
155
+ basis['fully_fused'] = basis_fused
156
+
157
+ del basis['0,0'] # We know that the basis for l = k = 0 is filled with a constant
158
+ return basis
159
+
160
+
161
+ def get_basis(relative_pos: Tensor,
162
+ max_degree: int = 4,
163
+ compute_gradients: bool = False,
164
+ use_pad_trick: bool = False,
165
+ amp: bool = False) -> Dict[str, Tensor]:
166
+ with nvtx_range('spherical harmonics'):
167
+ spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
168
+ with nvtx_range('CB coefficients'):
169
+ clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)
170
+
171
+ with torch.autograd.set_grad_enabled(compute_gradients):
172
+ with nvtx_range('bases'):
173
+ basis = get_basis_script(max_degree=max_degree,
174
+ use_pad_trick=use_pad_trick,
175
+ spherical_harmonics=spherical_harmonics,
176
+ clebsch_gordon=clebsch_gordon,
177
+ amp=amp)
178
+ return basis
env/SE3Transformer/build/lib/se3_transformer/model/fiber.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+
25
+ from collections import namedtuple
26
+ from itertools import product
27
+ from typing import Dict
28
+
29
+ import torch
30
+ from torch import Tensor
31
+
32
+ from se3_transformer.runtime.utils import degree_to_dim
33
+
34
+ FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
35
+
36
+
37
+ class Fiber(dict):
38
+ """
39
+ Describes the structure of some set of features.
40
+ Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
41
+ Type-0 features: invariant scalars
42
+ Type-1 features: equivariant 3D vectors
43
+ Type-2 features: equivariant symmetric traceless matrices
44
+ ...
45
+
46
+ As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
47
+ The 'multiplicity' or 'number of channels' is the number of features of a given type.
48
+ This class puts together all the degrees and their multiplicities in order to describe
49
+ the inputs, outputs or hidden features of SE3 layers.
50
+ """
51
+
52
+ def __init__(self, structure):
53
+ if isinstance(structure, dict):
54
+ structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
55
+ elif not isinstance(structure[0], FiberEl):
56
+ structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
57
+ self.structure = structure
58
+ super().__init__({d: m for d, m in self.structure})
59
+
60
+ @property
61
+ def degrees(self):
62
+ return sorted([t.degree for t in self.structure])
63
+
64
+ @property
65
+ def channels(self):
66
+ return [self[d] for d in self.degrees]
67
+
68
+ @property
69
+ def num_features(self):
70
+ """ Size of the resulting tensor if all features were concatenated together """
71
+ return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
72
+
73
+ @staticmethod
74
+ def create(num_degrees: int, num_channels: int):
75
+ """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
76
+ return Fiber([(degree, num_channels) for degree in range(num_degrees)])
77
+
78
+ @staticmethod
79
+ def from_features(feats: Dict[str, Tensor]):
80
+ """ Infer the Fiber structure from a feature dict """
81
+ structure = {}
82
+ for k, v in feats.items():
83
+ degree = int(k)
84
+ assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
85
+ assert v.shape[-1] == degree_to_dim(degree)
86
+ structure[degree] = v.shape[-2]
87
+ return Fiber(structure)
88
+
89
+ def __getitem__(self, degree: int):
90
+ """ fiber[degree] returns the multiplicity for this degree """
91
+ return dict(self.structure).get(degree, 0)
92
+
93
+ def __iter__(self):
94
+ """ Iterate over namedtuples (degree, channels) """
95
+ return iter(self.structure)
96
+
97
+ def __mul__(self, other):
98
+ """
99
+ If other in an int, multiplies all the multiplicities by other.
100
+ If other is a fiber, returns the cartesian product.
101
+ """
102
+ if isinstance(other, Fiber):
103
+ return product(self.structure, other.structure)
104
+ elif isinstance(other, int):
105
+ return Fiber({t.degree: t.channels * other for t in self.structure})
106
+
107
+ def __add__(self, other):
108
+ """
109
+ If other in an int, add other to all the multiplicities.
110
+ If other is a fiber, add the multiplicities of the fibers together.
111
+ """
112
+ if isinstance(other, Fiber):
113
+ return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
114
+ elif isinstance(other, int):
115
+ return Fiber({t.degree: t.channels + other for t in self.structure})
116
+
117
+ def __repr__(self):
118
+ return str(self.structure)
119
+
120
+ @staticmethod
121
+ def combine_max(f1, f2):
122
+ """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
123
+ new_dict = dict(f1.structure)
124
+ for k, m in f2.structure:
125
+ new_dict[k] = max(new_dict.get(k, 0), m)
126
+
127
+ return Fiber(list(new_dict.items()))
128
+
129
+ @staticmethod
130
+ def combine_selectively(f1, f2):
131
+ """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
132
+ # only use orders which occur in fiber f1
133
+ new_dict = dict(f1.structure)
134
+ for k in f1.degrees:
135
+ if k in f2.degrees:
136
+ new_dict[k] += f2[k]
137
+ return Fiber(list(new_dict.items()))
138
+
139
+ def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
140
+ # dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
141
+ fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
142
+ self.degrees]
143
+ fibers = torch.cat(fibers, -1)
144
+ return fibers
env/SE3Transformer/build/lib/se3_transformer/model/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .linear import LinearSE3
2
+ from .norm import NormSE3
3
+ from .pooling import GPooling
4
+ from .convolution import ConvSE3
5
+ from .attention import AttentionBlockSE3
env/SE3Transformer/build/lib/se3_transformer/model/layers/attention.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import dgl
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ from dgl import DGLGraph
29
+ from dgl.ops import edge_softmax
30
+ from torch import Tensor
31
+ from typing import Dict, Optional, Union
32
+
33
+ from se3_transformer.model.fiber import Fiber
34
+ from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
35
+ from se3_transformer.model.layers.linear import LinearSE3
36
+ from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
37
+ from torch.cuda.nvtx import range as nvtx_range
38
+
39
+
40
+ class AttentionSE3(nn.Module):
41
+ """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """
42
+
43
+ def __init__(
44
+ self,
45
+ num_heads: int,
46
+ key_fiber: Fiber,
47
+ value_fiber: Fiber
48
+ ):
49
+ """
50
+ :param num_heads: Number of attention heads
51
+ :param key_fiber: Fiber for the keys (and also for the queries)
52
+ :param value_fiber: Fiber for the values
53
+ """
54
+ super().__init__()
55
+ self.num_heads = num_heads
56
+ self.key_fiber = key_fiber
57
+ self.value_fiber = value_fiber
58
+
59
+ def forward(
60
+ self,
61
+ value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
62
+ key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
63
+ query: Dict[str, Tensor], # node features
64
+ graph: DGLGraph
65
+ ):
66
+ with nvtx_range('AttentionSE3'):
67
+ with nvtx_range('reshape keys and queries'):
68
+ if isinstance(key, Tensor):
69
+ # case where features of all types are fused
70
+ key = key.reshape(key.shape[0], self.num_heads, -1)
71
+ # need to reshape queries that way to keep the same layout as keys
72
+ out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
73
+ query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
74
+ else:
75
+ # features are not fused, need to fuse and reshape them
76
+ key = self.key_fiber.to_attention_heads(key, self.num_heads)
77
+ query = self.key_fiber.to_attention_heads(query, self.num_heads)
78
+
79
+ with nvtx_range('attention dot product + softmax'):
80
+ # Compute attention weights (softmax of inner product between key and query)
81
+ edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
82
+ edge_weights /= np.sqrt(self.key_fiber.num_features)
83
+ edge_weights = edge_softmax(graph, edge_weights)
84
+ edge_weights = edge_weights[..., None, None]
85
+
86
+ with nvtx_range('weighted sum'):
87
+ if isinstance(value, Tensor):
88
+ # features of all types are fused
89
+ v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
90
+ weights = edge_weights * v
91
+ feat_out = dgl.ops.copy_e_sum(graph, weights)
92
+ feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
93
+ out = unfuse_features(feat_out, self.value_fiber.degrees)
94
+ else:
95
+ out = {}
96
+ for degree, channels in self.value_fiber:
97
+ v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
98
+ degree_to_dim(degree))
99
+ weights = edge_weights * v
100
+ res = dgl.ops.copy_e_sum(graph, weights)
101
+ out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
102
+
103
+ return out
104
+
105
+
106
+ class AttentionBlockSE3(nn.Module):
107
+ """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
108
+
109
+ def __init__(
110
+ self,
111
+ fiber_in: Fiber,
112
+ fiber_out: Fiber,
113
+ fiber_edge: Optional[Fiber] = None,
114
+ num_heads: int = 4,
115
+ channels_div: int = 2,
116
+ use_layer_norm: bool = False,
117
+ max_degree: bool = 4,
118
+ fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
119
+ **kwargs
120
+ ):
121
+ """
122
+ :param fiber_in: Fiber describing the input features
123
+ :param fiber_out: Fiber describing the output features
124
+ :param fiber_edge: Fiber describing the edge features (node distances excluded)
125
+ :param num_heads: Number of attention heads
126
+ :param channels_div: Divide the channels by this integer for computing values
127
+ :param use_layer_norm: Apply layer normalization between MLP layers
128
+ :param max_degree: Maximum degree used in the bases computation
129
+ :param fuse_level: Maximum fuse level to use in TFN convolutions
130
+ """
131
+ super().__init__()
132
+ if fiber_edge is None:
133
+ fiber_edge = Fiber({})
134
+ self.fiber_in = fiber_in
135
+ # value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
136
+ value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
137
+ # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
138
+ # (queries are merely projected, hence degrees have to match input)
139
+ key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
140
+
141
+ self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
142
+ use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
143
+ allow_fused_output=True)
144
+ self.to_query = LinearSE3(fiber_in, key_query_fiber)
145
+ self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
146
+ self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
147
+
148
+ def forward(
149
+ self,
150
+ node_features: Dict[str, Tensor],
151
+ edge_features: Dict[str, Tensor],
152
+ graph: DGLGraph,
153
+ basis: Dict[str, Tensor]
154
+ ):
155
+ with nvtx_range('AttentionBlockSE3'):
156
+ with nvtx_range('keys / values'):
157
+ fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
158
+ key, value = self._get_key_value_from_fused(fused_key_value)
159
+
160
+ with nvtx_range('queries'):
161
+ query = self.to_query(node_features)
162
+
163
+ z = self.attention(value, key, query, graph)
164
+ z_concat = aggregate_residual(node_features, z, 'cat')
165
+ return self.project(z_concat)
166
+
167
+ def _get_key_value_from_fused(self, fused_key_value):
168
+ # Extract keys and queries features from fused features
169
+ if isinstance(fused_key_value, Tensor):
170
+ # Previous layer was a fully fused convolution
171
+ value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
172
+ else:
173
+ key, value = {}, {}
174
+ for degree, feat in fused_key_value.items():
175
+ if int(degree) in self.fiber_in.degrees:
176
+ value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
177
+ else:
178
+ value[degree] = feat
179
+
180
+ return key, value
env/SE3Transformer/build/lib/se3_transformer/model/layers/convolution.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ from enum import Enum
25
+ from itertools import product
26
+ from typing import Dict
27
+
28
+ import dgl
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ from dgl import DGLGraph
33
+ from torch import Tensor
34
+ from torch.cuda.nvtx import range as nvtx_range
35
+
36
+ from se3_transformer.model.fiber import Fiber
37
+ from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
38
+
39
+
40
+ class ConvSE3FuseLevel(Enum):
41
+ """
42
+ Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
43
+ If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
44
+ A higher level means faster training, but also more memory usage.
45
+ If you are tight on memory and want to feed large inputs to the network, choose a low value.
46
+ If you want to train fast, choose a high value.
47
+ Recommended value is FULL with AMP.
48
+
49
+ Fully fused TFN convolutions requirements:
50
+ - all input channels are the same
51
+ - all output channels are the same
52
+ - input degrees span the range [0, ..., max_degree]
53
+ - output degrees span the range [0, ..., max_degree]
54
+
55
+ Partially fused TFN convolutions requirements:
56
+ * For fusing by output degree:
57
+ - all input channels are the same
58
+ - input degrees span the range [0, ..., max_degree]
59
+ * For fusing by input degree:
60
+ - all output channels are the same
61
+ - output degrees span the range [0, ..., max_degree]
62
+
63
+ Original TFN pairwise convolutions: no requirements
64
+ """
65
+
66
+ FULL = 2
67
+ PARTIAL = 1
68
+ NONE = 0
69
+
70
+
71
+ class RadialProfile(nn.Module):
72
+ """
73
+ Radial profile function.
74
+ Outputs weights used to weigh basis matrices in order to get convolution kernels.
75
+ In TFN notation: $R^{l,k}$
76
+ In SE(3)-Transformer notation: $\phi^{l,k}$
77
+
78
+ Note:
79
+ In the original papers, this function only depends on relative node distances ||x||.
80
+ Here, we allow this function to also take as input additional invariant edge features.
81
+ This does not break equivariance and adds expressive power to the model.
82
+
83
+ Diagram:
84
+ invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ num_freq: int,
90
+ channels_in: int,
91
+ channels_out: int,
92
+ edge_dim: int = 1,
93
+ mid_dim: int = 32,
94
+ use_layer_norm: bool = False
95
+ ):
96
+ """
97
+ :param num_freq: Number of frequencies
98
+ :param channels_in: Number of input channels
99
+ :param channels_out: Number of output channels
100
+ :param edge_dim: Number of invariant edge features (input to the radial function)
101
+ :param mid_dim: Size of the hidden MLP layers
102
+ :param use_layer_norm: Apply layer normalization between MLP layers
103
+ """
104
+ super().__init__()
105
+ modules = [
106
+ nn.Linear(edge_dim, mid_dim),
107
+ nn.LayerNorm(mid_dim) if use_layer_norm else None,
108
+ nn.ReLU(),
109
+ nn.Linear(mid_dim, mid_dim),
110
+ nn.LayerNorm(mid_dim) if use_layer_norm else None,
111
+ nn.ReLU(),
112
+ nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
113
+ ]
114
+
115
+ self.net = nn.Sequential(*[m for m in modules if m is not None])
116
+
117
+ def forward(self, features: Tensor) -> Tensor:
118
+ return self.net(features)
119
+
120
+
121
+ class VersatileConvSE3(nn.Module):
122
+ """
123
+ Building block for TFN convolutions.
124
+ This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
125
+ """
126
+
127
+ def __init__(self,
128
+ freq_sum: int,
129
+ channels_in: int,
130
+ channels_out: int,
131
+ edge_dim: int,
132
+ use_layer_norm: bool,
133
+ fuse_level: ConvSE3FuseLevel):
134
+ super().__init__()
135
+ self.freq_sum = freq_sum
136
+ self.channels_out = channels_out
137
+ self.channels_in = channels_in
138
+ self.fuse_level = fuse_level
139
+ self.radial_func = RadialProfile(num_freq=freq_sum,
140
+ channels_in=channels_in,
141
+ channels_out=channels_out,
142
+ edge_dim=edge_dim,
143
+ use_layer_norm=use_layer_norm)
144
+
145
+ def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
146
+ with nvtx_range(f'VersatileConvSE3'):
147
+ num_edges = features.shape[0]
148
+ in_dim = features.shape[2]
149
+ with nvtx_range(f'RadialProfile'):
150
+ radial_weights = self.radial_func(invariant_edge_feats) \
151
+ .view(-1, self.channels_out, self.channels_in * self.freq_sum)
152
+
153
+ if basis is not None:
154
+ # This block performs the einsum n i l, n o i f, n l f k -> n o k
155
+ out_dim = basis.shape[-1]
156
+ if self.fuse_level != ConvSE3FuseLevel.FULL:
157
+ out_dim += out_dim % 2 - 1 # Account for padded basis
158
+ basis_view = basis.view(num_edges, in_dim, -1)
159
+ tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
160
+ return (radial_weights @ tmp)[:, :, :out_dim]
161
+ else:
162
+ # k = l = 0 non-fused case
163
+ return radial_weights @ features
164
+
165
+
166
+ class ConvSE3(nn.Module):
167
+ """
168
+ SE(3)-equivariant graph convolution (Tensor Field Network convolution).
169
+ This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
170
+ Features of different degrees interact together to produce output features.
171
+
172
+ Note 1:
173
+ The option is given to not pool the output. This means that the convolution sum over neighbors will not be
174
+ done, and the returned features will be edge features instead of node features.
175
+
176
+ Note 2:
177
+ Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
178
+ Input edge features are concatenated with input source node features before the kernel is applied.
179
+ """
180
+
181
+ def __init__(
182
+ self,
183
+ fiber_in: Fiber,
184
+ fiber_out: Fiber,
185
+ fiber_edge: Fiber,
186
+ pool: bool = True,
187
+ use_layer_norm: bool = False,
188
+ self_interaction: bool = False,
189
+ max_degree: int = 4,
190
+ fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
191
+ allow_fused_output: bool = False
192
+ ):
193
+ """
194
+ :param fiber_in: Fiber describing the input features
195
+ :param fiber_out: Fiber describing the output features
196
+ :param fiber_edge: Fiber describing the edge features (node distances excluded)
197
+ :param pool: If True, compute final node features by averaging incoming edge features
198
+ :param use_layer_norm: Apply layer normalization between MLP layers
199
+ :param self_interaction: Apply self-interaction of nodes
200
+ :param max_degree: Maximum degree used in the bases computation
201
+ :param fuse_level: Maximum fuse level to use in TFN convolutions
202
+ :param allow_fused_output: Allow the module to output a fused representation of features
203
+ """
204
+ super().__init__()
205
+ self.pool = pool
206
+ self.fiber_in = fiber_in
207
+ self.fiber_out = fiber_out
208
+ self.self_interaction = self_interaction
209
+ self.max_degree = max_degree
210
+ self.allow_fused_output = allow_fused_output
211
+
212
+ # channels_in: account for the concatenation of edge features
213
+ channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
214
+ channels_out_set = set([f.channels for f in self.fiber_out])
215
+ unique_channels_in = (len(channels_in_set) == 1)
216
+ unique_channels_out = (len(channels_out_set) == 1)
217
+ degrees_up_to_max = list(range(max_degree + 1))
218
+ common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
219
+
220
+ if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
221
+ unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
222
+ unique_channels_out and fiber_out.degrees == degrees_up_to_max:
223
+ # Single fused convolution
224
+ self.used_fuse_level = ConvSE3FuseLevel.FULL
225
+
226
+ sum_freq = sum([
227
+ degree_to_dim(min(d_in, d_out))
228
+ for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
229
+ ])
230
+
231
+ self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
232
+ fuse_level=self.used_fuse_level, **common_args)
233
+
234
+ elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
235
+ unique_channels_in and fiber_in.degrees == degrees_up_to_max:
236
+ # Convolutions fused per output degree
237
+ self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
238
+ self.conv_out = nn.ModuleDict()
239
+ for d_out, c_out in fiber_out:
240
+ sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
241
+ self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
242
+ fuse_level=self.used_fuse_level, **common_args)
243
+
244
+ elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
245
+ unique_channels_out and fiber_out.degrees == degrees_up_to_max:
246
+ # Convolutions fused per input degree
247
+ self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
248
+ self.conv_in = nn.ModuleDict()
249
+ for d_in, c_in in fiber_in:
250
+ sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
251
+ self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
252
+ fuse_level=ConvSE3FuseLevel.FULL, **common_args)
253
+ #fuse_level=self.used_fuse_level, **common_args)
254
+ else:
255
+ # Use pairwise TFN convolutions
256
+ self.used_fuse_level = ConvSE3FuseLevel.NONE
257
+ self.conv = nn.ModuleDict()
258
+ for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
259
+ dict_key = f'{degree_in},{degree_out}'
260
+ channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
261
+ sum_freq = degree_to_dim(min(degree_in, degree_out))
262
+ self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
263
+ fuse_level=self.used_fuse_level, **common_args)
264
+
265
+ if self_interaction:
266
+ self.to_kernel_self = nn.ParameterDict()
267
+ for degree_out, channels_out in fiber_out:
268
+ if fiber_in[degree_out]:
269
+ self.to_kernel_self[str(degree_out)] = nn.Parameter(
270
+ torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
271
+
272
+ def forward(
273
+ self,
274
+ node_feats: Dict[str, Tensor],
275
+ edge_feats: Dict[str, Tensor],
276
+ graph: DGLGraph,
277
+ basis: Dict[str, Tensor]
278
+ ):
279
+ with nvtx_range(f'ConvSE3'):
280
+ invariant_edge_feats = edge_feats['0'].squeeze(-1)
281
+ src, dst = graph.edges()
282
+ out = {}
283
+ in_features = []
284
+
285
+ # Fetch all input features from edge and node features
286
+ for degree_in in self.fiber_in.degrees:
287
+ src_node_features = node_feats[str(degree_in)][src]
288
+ if degree_in > 0 and str(degree_in) in edge_feats:
289
+ # Handle edge features of any type by concatenating them to node features
290
+ src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
291
+ in_features.append(src_node_features)
292
+
293
+ if self.used_fuse_level == ConvSE3FuseLevel.FULL:
294
+ in_features_fused = torch.cat(in_features, dim=-1)
295
+ out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
296
+
297
+ if not self.allow_fused_output or self.self_interaction or self.pool:
298
+ out = unfuse_features(out, self.fiber_out.degrees)
299
+
300
+ elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
301
+ in_features_fused = torch.cat(in_features, dim=-1)
302
+ for degree_out in self.fiber_out.degrees:
303
+ out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
304
+ basis[f'out{degree_out}_fused'])
305
+
306
+ elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
307
+ out = 0
308
+ for degree_in, feature in zip(self.fiber_in.degrees, in_features):
309
+ out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
310
+ basis[f'in{degree_in}_fused'])
311
+ if not self.allow_fused_output or self.self_interaction or self.pool:
312
+ out = unfuse_features(out, self.fiber_out.degrees)
313
+ else:
314
+ # Fallback to pairwise TFN convolutions
315
+ for degree_out in self.fiber_out.degrees:
316
+ out_feature = 0
317
+ for degree_in, feature in zip(self.fiber_in.degrees, in_features):
318
+ dict_key = f'{degree_in},{degree_out}'
319
+ out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
320
+ basis.get(dict_key, None))
321
+ out[str(degree_out)] = out_feature
322
+
323
+ for degree_out in self.fiber_out.degrees:
324
+ if self.self_interaction and str(degree_out) in self.to_kernel_self:
325
+ with nvtx_range(f'self interaction'):
326
+ dst_features = node_feats[str(degree_out)][dst]
327
+ kernel_self = self.to_kernel_self[str(degree_out)]
328
+ out[str(degree_out)] += kernel_self @ dst_features
329
+
330
+ if self.pool:
331
+ with nvtx_range(f'pooling'):
332
+ if isinstance(out, dict):
333
+ out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
334
+ else:
335
+ out = dgl.ops.copy_e_sum(graph, out)
336
+ return out
env/SE3Transformer/build/lib/se3_transformer/model/layers/linear.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+
25
+ from typing import Dict
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ from torch import Tensor
31
+
32
+ from se3_transformer.model.fiber import Fiber
33
+
34
+
35
+ class LinearSE3(nn.Module):
36
+ """
37
+ Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
38
+ Maps a fiber to a fiber with the same degrees (channels may be different).
39
+ No interaction between degrees, but interaction between channels.
40
+
41
+ type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
42
+ type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
43
+ :
44
+ type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
45
+ """
46
+
47
+ def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
48
+ super().__init__()
49
+ self.weights = nn.ParameterDict({
50
+ str(degree_out): nn.Parameter(
51
+ torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
52
+ for degree_out, channels_out in fiber_out
53
+ })
54
+
55
+ def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
56
+ return {
57
+ degree: self.weights[degree] @ features[degree]
58
+ for degree, weight in self.weights.items()
59
+ }
env/SE3Transformer/build/lib/se3_transformer/model/layers/norm.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+
25
+ from typing import Dict
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ from torch import Tensor
30
+ from torch.cuda.nvtx import range as nvtx_range
31
+
32
+ from se3_transformer.model.fiber import Fiber
33
+
34
+
35
+ class NormSE3(nn.Module):
36
+ """
37
+ Norm-based SE(3)-equivariant nonlinearity.
38
+
39
+ ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
40
+ feature_in ──┤ * ──> feature_out
41
+ └──> feature_phase ────────────────────────────┘
42
+ """
43
+
44
+ NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
45
+
46
+ def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
47
+ super().__init__()
48
+ self.fiber = fiber
49
+ self.nonlinearity = nonlinearity
50
+
51
+ if len(set(fiber.channels)) == 1:
52
+ # Fuse all the layer normalizations into a group normalization
53
+ self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
54
+ else:
55
+ # Use multiple layer normalizations
56
+ self.layer_norms = nn.ModuleDict({
57
+ str(degree): nn.LayerNorm(channels)
58
+ for degree, channels in fiber
59
+ })
60
+
61
+ def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
62
+ with nvtx_range('NormSE3'):
63
+ output = {}
64
+ if hasattr(self, 'group_norm'):
65
+ # Compute per-degree norms of features
66
+ norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
67
+ for d in self.fiber.degrees]
68
+ fused_norms = torch.cat(norms, dim=-2)
69
+
70
+ # Transform the norms only
71
+ new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
72
+ new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
73
+
74
+ # Scale features to the new norms
75
+ for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
76
+ output[str(d)] = features[str(d)] / norm * new_norm
77
+ else:
78
+ for degree, feat in features.items():
79
+ norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
80
+ new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
81
+ output[degree] = new_norm * feat / norm
82
+
83
+ return output
env/SE3Transformer/build/lib/se3_transformer/model/layers/pooling.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ from typing import Dict, Literal
25
+
26
+ import torch.nn as nn
27
+ from dgl import DGLGraph
28
+ from dgl.nn.pytorch import AvgPooling, MaxPooling
29
+ from torch import Tensor
30
+
31
+
32
+ class GPooling(nn.Module):
33
+ """
34
+ Graph max/average pooling on a given feature type.
35
+ The average can be taken for any feature type, and equivariance will be maintained.
36
+ The maximum can only be taken for invariant features (type 0).
37
+ If you want max-pooling for type > 0 features, look into Vector Neurons.
38
+ """
39
+
40
+ def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
41
+ """
42
+ :param feat_type: Feature type to pool
43
+ :param pool: Type of pooling: max or avg
44
+ """
45
+ super().__init__()
46
+ assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
47
+ assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
48
+ self.feat_type = feat_type
49
+ self.pool = MaxPooling() if pool == 'max' else AvgPooling()
50
+
51
+ def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
52
+ pooled = self.pool(graph, features[str(self.feat_type)])
53
+ return pooled.squeeze(dim=-1)
env/SE3Transformer/build/lib/se3_transformer/model/transformer.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import logging
25
+ from typing import Optional, Literal, Dict
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ from dgl import DGLGraph
30
+ from torch import Tensor
31
+
32
+ from se3_transformer.model.basis import get_basis, update_basis_with_fused
33
+ from se3_transformer.model.layers.attention import AttentionBlockSE3
34
+ from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
35
+ from se3_transformer.model.layers.norm import NormSE3
36
+ from se3_transformer.model.layers.pooling import GPooling
37
+ from se3_transformer.runtime.utils import str2bool
38
+ from se3_transformer.model.fiber import Fiber
39
+
40
+
41
+ class Sequential(nn.Sequential):
42
+ """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
43
+
44
+ def forward(self, input, *args, **kwargs):
45
+ for module in self:
46
+ input = module(input, *args, **kwargs)
47
+ return input
48
+
49
+
50
+ def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
51
+ """ Add relative positions to existing edge features """
52
+ edge_features = edge_features.copy() if edge_features else {}
53
+ r = relative_pos.norm(dim=-1, keepdim=True)
54
+ if '0' in edge_features:
55
+ edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
56
+ else:
57
+ edge_features['0'] = r[..., None]
58
+
59
+ return edge_features
60
+
61
+
62
+ class SE3Transformer(nn.Module):
63
+ def __init__(self,
64
+ num_layers: int,
65
+ fiber_in: Fiber,
66
+ fiber_hidden: Fiber,
67
+ fiber_out: Fiber,
68
+ num_heads: int,
69
+ channels_div: int,
70
+ fiber_edge: Fiber = Fiber({}),
71
+ return_type: Optional[int] = None,
72
+ pooling: Optional[Literal['avg', 'max']] = None,
73
+ norm: bool = True,
74
+ use_layer_norm: bool = True,
75
+ tensor_cores: bool = False,
76
+ low_memory: bool = False,
77
+ **kwargs):
78
+ """
79
+ :param num_layers: Number of attention layers
80
+ :param fiber_in: Input fiber description
81
+ :param fiber_hidden: Hidden fiber description
82
+ :param fiber_out: Output fiber description
83
+ :param fiber_edge: Input edge fiber description
84
+ :param num_heads: Number of attention heads
85
+ :param channels_div: Channels division before feeding to attention layer
86
+ :param return_type: Return only features of this type
87
+ :param pooling: 'avg' or 'max' graph pooling before MLP layers
88
+ :param norm: Apply a normalization layer after each attention block
89
+ :param use_layer_norm: Apply layer normalization between MLP layers
90
+ :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
91
+ :param low_memory: If True, will use slower ops that use less memory
92
+ """
93
+ super().__init__()
94
+ self.num_layers = num_layers
95
+ self.fiber_edge = fiber_edge
96
+ self.num_heads = num_heads
97
+ self.channels_div = channels_div
98
+ self.return_type = return_type
99
+ self.pooling = pooling
100
+ self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
101
+ self.tensor_cores = tensor_cores
102
+ self.low_memory = low_memory
103
+
104
+ if low_memory and not tensor_cores:
105
+ logging.warning('Low memory mode will have no effect with no Tensor Cores')
106
+
107
+ # Fully fused convolutions when using Tensor Cores (and not low memory mode)
108
+ fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
109
+
110
+ graph_modules = []
111
+ for i in range(num_layers):
112
+ graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
113
+ fiber_out=fiber_hidden,
114
+ fiber_edge=fiber_edge,
115
+ num_heads=num_heads,
116
+ channels_div=channels_div,
117
+ use_layer_norm=use_layer_norm,
118
+ max_degree=self.max_degree,
119
+ fuse_level=fuse_level))
120
+ if norm:
121
+ graph_modules.append(NormSE3(fiber_hidden))
122
+ fiber_in = fiber_hidden
123
+
124
+ graph_modules.append(ConvSE3(fiber_in=fiber_in,
125
+ fiber_out=fiber_out,
126
+ fiber_edge=fiber_edge,
127
+ self_interaction=True,
128
+ use_layer_norm=use_layer_norm,
129
+ max_degree=self.max_degree))
130
+ self.graph_modules = Sequential(*graph_modules)
131
+
132
+ if pooling is not None:
133
+ assert return_type is not None, 'return_type must be specified when pooling'
134
+ self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
135
+
136
+ def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
137
+ edge_feats: Optional[Dict[str, Tensor]] = None,
138
+ basis: Optional[Dict[str, Tensor]] = None):
139
+ # Compute bases in case they weren't precomputed as part of the data loading
140
+ basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
141
+ use_pad_trick=self.tensor_cores and not self.low_memory,
142
+ amp=torch.is_autocast_enabled())
143
+
144
+ # Add fused bases (per output degree, per input degree, and fully fused) to the dict
145
+ basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
146
+ fully_fused=self.tensor_cores and not self.low_memory)
147
+
148
+ edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
149
+
150
+ node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
151
+
152
+ if self.pooling is not None:
153
+ return self.pooling_module(node_feats, graph=graph)
154
+
155
+ if self.return_type is not None:
156
+ return node_feats[str(self.return_type)]
157
+
158
+ return node_feats
159
+
160
+ @staticmethod
161
+ def add_argparse_args(parser):
162
+ parser.add_argument('--num_layers', type=int, default=7,
163
+ help='Number of stacked Transformer layers')
164
+ parser.add_argument('--num_heads', type=int, default=8,
165
+ help='Number of heads in self-attention')
166
+ parser.add_argument('--channels_div', type=int, default=2,
167
+ help='Channels division before feeding to attention layer')
168
+ parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
169
+ help='Type of graph pooling')
170
+ parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
171
+ help='Apply a normalization layer after each attention block')
172
+ parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
173
+ help='Apply layer normalization between MLP layers')
174
+ parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
175
+ help='If true, will use fused ops that are slower but that use less memory '
176
+ '(expect 25 percent less memory). '
177
+ 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
178
+
179
+ return parser
180
+
181
+
182
+ class SE3TransformerPooled(nn.Module):
183
+ def __init__(self,
184
+ fiber_in: Fiber,
185
+ fiber_out: Fiber,
186
+ fiber_edge: Fiber,
187
+ num_degrees: int,
188
+ num_channels: int,
189
+ output_dim: int,
190
+ **kwargs):
191
+ super().__init__()
192
+ kwargs['pooling'] = kwargs['pooling'] or 'max'
193
+ self.transformer = SE3Transformer(
194
+ fiber_in=fiber_in,
195
+ fiber_hidden=Fiber.create(num_degrees, num_channels),
196
+ fiber_out=fiber_out,
197
+ fiber_edge=fiber_edge,
198
+ return_type=0,
199
+ **kwargs
200
+ )
201
+
202
+ n_out_features = fiber_out.num_features
203
+ self.mlp = nn.Sequential(
204
+ nn.Linear(n_out_features, n_out_features),
205
+ nn.ReLU(),
206
+ nn.Linear(n_out_features, output_dim)
207
+ )
208
+
209
+ def forward(self, graph, node_feats, edge_feats, basis=None):
210
+ feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
211
+ y = self.mlp(feats).squeeze(-1)
212
+ return y
213
+
214
+ @staticmethod
215
+ def add_argparse_args(parent_parser):
216
+ parser = parent_parser.add_argument_group("Model architecture")
217
+ SE3Transformer.add_argparse_args(parser)
218
+ parser.add_argument('--num_degrees',
219
+ help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
220
+ type=int, default=4)
221
+ parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
222
+ return parent_parser
env/SE3Transformer/build/lib/se3_transformer/runtime/__init__.py ADDED
File without changes
env/SE3Transformer/build/lib/se3_transformer/runtime/arguments.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import argparse
25
+ import pathlib
26
+
27
+ from se3_transformer.data_loading import QM9DataModule
28
+ from se3_transformer.model import SE3TransformerPooled
29
+ from se3_transformer.runtime.utils import str2bool
30
+
31
+ PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
32
+
33
+ paths = PARSER.add_argument_group('Paths')
34
+ paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
35
+ help='Directory where the data is located or should be downloaded')
36
+ paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
37
+ help='Directory where the results logs should be saved')
38
+ paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
39
+ help='Name for the resulting DLLogger JSON file')
40
+ paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
41
+ help='File where the checkpoint should be saved')
42
+ paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
43
+ help='File of the checkpoint to be loaded')
44
+
45
+ optimizer = PARSER.add_argument_group('Optimizer')
46
+ optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
47
+ optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
48
+ optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
49
+ optimizer.add_argument('--momentum', type=float, default=0.9)
50
+ optimizer.add_argument('--weight_decay', type=float, default=0.1)
51
+
52
+ PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
53
+ PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
54
+ PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
55
+ PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
56
+
57
+ PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
58
+ PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
59
+ PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
60
+ PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
61
+ PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
62
+ help='Do an evaluation round every N epochs')
63
+ PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
64
+ help='Minimize stdout output')
65
+
66
+ PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
67
+ help='Benchmark mode')
68
+
69
+ QM9DataModule.add_argparse_args(PARSER)
70
+ SE3TransformerPooled.add_argparse_args(PARSER)
env/SE3Transformer/build/lib/se3_transformer/runtime/callbacks.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import logging
25
+ import time
26
+ from abc import ABC, abstractmethod
27
+ from typing import Optional
28
+
29
+ import numpy as np
30
+ import torch
31
+
32
+ from se3_transformer.runtime.loggers import Logger
33
+ from se3_transformer.runtime.metrics import MeanAbsoluteError
34
+
35
+
36
+ class BaseCallback(ABC):
37
+ def on_fit_start(self, optimizer, args):
38
+ pass
39
+
40
+ def on_fit_end(self):
41
+ pass
42
+
43
+ def on_epoch_end(self):
44
+ pass
45
+
46
+ def on_batch_start(self):
47
+ pass
48
+
49
+ def on_validation_step(self, input, target, pred):
50
+ pass
51
+
52
+ def on_validation_end(self, epoch=None):
53
+ pass
54
+
55
+ def on_checkpoint_load(self, checkpoint):
56
+ pass
57
+
58
+ def on_checkpoint_save(self, checkpoint):
59
+ pass
60
+
61
+
62
+ class LRSchedulerCallback(BaseCallback):
63
+ def __init__(self, logger: Optional[Logger] = None):
64
+ self.logger = logger
65
+ self.scheduler = None
66
+
67
+ @abstractmethod
68
+ def get_scheduler(self, optimizer, args):
69
+ pass
70
+
71
+ def on_fit_start(self, optimizer, args):
72
+ self.scheduler = self.get_scheduler(optimizer, args)
73
+
74
+ def on_checkpoint_load(self, checkpoint):
75
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
76
+
77
+ def on_checkpoint_save(self, checkpoint):
78
+ checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
79
+
80
+ def on_epoch_end(self):
81
+ if self.logger is not None:
82
+ self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
83
+ self.scheduler.step()
84
+
85
+
86
+ class QM9MetricCallback(BaseCallback):
87
+ """ Logs the rescaled mean absolute error for QM9 regression tasks """
88
+
89
+ def __init__(self, logger, targets_std, prefix=''):
90
+ self.mae = MeanAbsoluteError()
91
+ self.logger = logger
92
+ self.targets_std = targets_std
93
+ self.prefix = prefix
94
+ self.best_mae = float('inf')
95
+
96
+ def on_validation_step(self, input, target, pred):
97
+ self.mae(pred.detach(), target.detach())
98
+
99
+ def on_validation_end(self, epoch=None):
100
+ mae = self.mae.compute() * self.targets_std
101
+ logging.info(f'{self.prefix} MAE: {mae}')
102
+ self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
103
+ self.best_mae = min(self.best_mae, mae)
104
+
105
+ def on_fit_end(self):
106
+ if self.best_mae != float('inf'):
107
+ self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})
108
+
109
+
110
+ class QM9LRSchedulerCallback(LRSchedulerCallback):
111
+ def __init__(self, logger, epochs):
112
+ super().__init__(logger)
113
+ self.epochs = epochs
114
+
115
+ def get_scheduler(self, optimizer, args):
116
+ min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
117
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)
118
+
119
+
120
+ class PerformanceCallback(BaseCallback):
121
+ def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
122
+ self.batch_size = batch_size
123
+ self.warmup_epochs = warmup_epochs
124
+ self.epoch = 0
125
+ self.timestamps = []
126
+ self.mode = mode
127
+ self.logger = logger
128
+
129
+ def on_batch_start(self):
130
+ if self.epoch >= self.warmup_epochs:
131
+ self.timestamps.append(time.time() * 1000.0)
132
+
133
+ def _log_perf(self):
134
+ stats = self.process_performance_stats()
135
+ for k, v in stats.items():
136
+ logging.info(f'performance {k}: {v}')
137
+
138
+ self.logger.log_metrics(stats)
139
+
140
+ def on_epoch_end(self):
141
+ self.epoch += 1
142
+
143
+ def on_fit_end(self):
144
+ if self.epoch > self.warmup_epochs:
145
+ self._log_perf()
146
+ self.timestamps = []
147
+
148
+ def process_performance_stats(self):
149
+ timestamps = np.asarray(self.timestamps)
150
+ deltas = np.diff(timestamps)
151
+ throughput = (self.batch_size / deltas).mean()
152
+ stats = {
153
+ f"throughput_{self.mode}": throughput,
154
+ f"latency_{self.mode}_mean": deltas.mean(),
155
+ f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
156
+ }
157
+ for level in [90, 95, 99]:
158
+ stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})
159
+
160
+ return stats
env/SE3Transformer/build/lib/se3_transformer/runtime/gpu_affinity.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import collections
25
+ import itertools
26
+ import math
27
+ import os
28
+ import pathlib
29
+ import re
30
+
31
+ import pynvml
32
+
33
+
34
+ class Device:
35
+ # assumes nvml returns list of 64 bit ints
36
+ _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
37
+
38
+ def __init__(self, device_idx):
39
+ super().__init__()
40
+ self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
41
+
42
+ def get_name(self):
43
+ return pynvml.nvmlDeviceGetName(self.handle)
44
+
45
+ def get_uuid(self):
46
+ return pynvml.nvmlDeviceGetUUID(self.handle)
47
+
48
+ def get_cpu_affinity(self):
49
+ affinity_string = ""
50
+ for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
51
+ # assume nvml returns list of 64 bit ints
52
+ affinity_string = "{:064b}".format(j) + affinity_string
53
+
54
+ affinity_list = [int(x) for x in affinity_string]
55
+ affinity_list.reverse() # so core 0 is in 0th element of list
56
+
57
+ ret = [i for i, e in enumerate(affinity_list) if e != 0]
58
+ return ret
59
+
60
+
61
+ def get_thread_siblings_list():
62
+ """
63
+ Returns a list of 2-element integer tuples representing pairs of
64
+ hyperthreading cores.
65
+ """
66
+ path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
67
+ thread_siblings_list = []
68
+ pattern = re.compile(r"(\d+)\D(\d+)")
69
+ for fname in pathlib.Path(path[0]).glob(path[1:]):
70
+ with open(fname) as f:
71
+ content = f.read().strip()
72
+ res = pattern.findall(content)
73
+ if res:
74
+ pair = tuple(map(int, res[0]))
75
+ thread_siblings_list.append(pair)
76
+ return thread_siblings_list
77
+
78
+
79
+ def check_socket_affinities(socket_affinities):
80
+ # sets of cores should be either identical or disjoint
81
+ for i, j in itertools.product(socket_affinities, socket_affinities):
82
+ if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
83
+ raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")
84
+
85
+
86
+ def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
87
+ devices = [Device(i) for i in range(nproc_per_node)]
88
+ socket_affinities = [dev.get_cpu_affinity() for dev in devices]
89
+
90
+ if exclude_unavailable_cores:
91
+ available_cores = os.sched_getaffinity(0)
92
+ socket_affinities = [list(set(affinity) & available_cores) for affinity in socket_affinities]
93
+
94
+ check_socket_affinities(socket_affinities)
95
+
96
+ return socket_affinities
97
+
98
+
99
+ def set_socket_affinity(gpu_id):
100
+ """
101
+ The process is assigned with all available logical CPU cores from the CPU
102
+ socket connected to the GPU with a given id.
103
+
104
+ Args:
105
+ gpu_id: index of a GPU
106
+ """
107
+ dev = Device(gpu_id)
108
+ affinity = dev.get_cpu_affinity()
109
+ os.sched_setaffinity(0, affinity)
110
+
111
+
112
+ def set_single_affinity(gpu_id):
113
+ """
114
+ The process is assigned with the first available logical CPU core from the
115
+ list of all CPU cores from the CPU socket connected to the GPU with a given
116
+ id.
117
+
118
+ Args:
119
+ gpu_id: index of a GPU
120
+ """
121
+ dev = Device(gpu_id)
122
+ affinity = dev.get_cpu_affinity()
123
+
124
+ # exclude unavailable cores
125
+ available_cores = os.sched_getaffinity(0)
126
+ affinity = list(set(affinity) & available_cores)
127
+ os.sched_setaffinity(0, affinity[:1])
128
+
129
+
130
+ def set_single_unique_affinity(gpu_id, nproc_per_node):
131
+ """
132
+ The process is assigned with a single unique available physical CPU core
133
+ from the list of all CPU cores from the CPU socket connected to the GPU with
134
+ a given id.
135
+
136
+ Args:
137
+ gpu_id: index of a GPU
138
+ """
139
+ socket_affinities = get_socket_affinities(nproc_per_node)
140
+
141
+ siblings_list = get_thread_siblings_list()
142
+ siblings_dict = dict(siblings_list)
143
+
144
+ # remove siblings
145
+ for idx, socket_affinity in enumerate(socket_affinities):
146
+ socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
147
+
148
+ affinities = []
149
+ assigned = []
150
+
151
+ for socket_affinity in socket_affinities:
152
+ for core in socket_affinity:
153
+ if core not in assigned:
154
+ affinities.append([core])
155
+ assigned.append(core)
156
+ break
157
+ os.sched_setaffinity(0, affinities[gpu_id])
158
+
159
+
160
+ def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
161
+ """
162
+ The process is assigned with an unique subset of available physical CPU
163
+ cores from the CPU socket connected to a GPU with a given id.
164
+ Assignment automatically includes hyperthreading siblings (if siblings are
165
+ available).
166
+
167
+ Args:
168
+ gpu_id: index of a GPU
169
+ nproc_per_node: total number of processes per node
170
+ mode: mode
171
+ balanced: assign an equal number of physical cores to each process
172
+ """
173
+ socket_affinities = get_socket_affinities(nproc_per_node)
174
+
175
+ siblings_list = get_thread_siblings_list()
176
+ siblings_dict = dict(siblings_list)
177
+
178
+ # remove hyperthreading siblings
179
+ for idx, socket_affinity in enumerate(socket_affinities):
180
+ socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
181
+
182
+ socket_affinities_to_device_ids = collections.defaultdict(list)
183
+
184
+ for idx, socket_affinity in enumerate(socket_affinities):
185
+ socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
186
+
187
+ # compute minimal number of physical cores per GPU across all GPUs and
188
+ # sockets, code assigns this number of cores per GPU if balanced == True
189
+ min_physical_cores_per_gpu = min(
190
+ [len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()]
191
+ )
192
+
193
+ for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
194
+ devices_per_group = len(device_ids)
195
+ if balanced:
196
+ cores_per_device = min_physical_cores_per_gpu
197
+ socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
198
+ else:
199
+ cores_per_device = len(socket_affinity) // devices_per_group
200
+
201
+ for group_id, device_id in enumerate(device_ids):
202
+ if device_id == gpu_id:
203
+
204
+ # In theory there should be no difference in performance between
205
+ # 'interleaved' and 'continuous' pattern on Intel-based DGX-1,
206
+ # but 'continuous' should be better for DGX A100 because on AMD
207
+ # Rome 4 consecutive cores are sharing L3 cache.
208
+ # TODO: code doesn't attempt to automatically detect layout of
209
+ # L3 cache, also external environment may already exclude some
210
+ # cores, this code makes no attempt to detect it and to align
211
+ # mapping to multiples of 4.
212
+
213
+ if mode == "interleaved":
214
+ affinity = list(socket_affinity[group_id::devices_per_group])
215
+ elif mode == "continuous":
216
+ affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
217
+ else:
218
+ raise RuntimeError("Unknown set_socket_unique_affinity mode")
219
+
220
+ # unconditionally reintroduce hyperthreading siblings, this step
221
+ # may result in a different numbers of logical cores assigned to
222
+ # each GPU even if balanced == True (if hyperthreading siblings
223
+ # aren't available for a subset of cores due to some external
224
+ # constraints, siblings are re-added unconditionally, in the
225
+ # worst case unavailable logical core will be ignored by
226
+ # os.sched_setaffinity().
227
+ affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
228
+ os.sched_setaffinity(0, affinity)
229
+
230
+
231
+ def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
232
+ """
233
+ The process is assigned with a proper CPU affinity which matches hardware
234
+ architecture on a given platform. Usually it improves and stabilizes
235
+ performance of deep learning training workloads.
236
+
237
+ This function assumes that the workload is running in multi-process
238
+ single-device mode (there are multiple training processes and each process
239
+ is running on a single GPU), which is typical for multi-GPU training
240
+ workloads using `torch.nn.parallel.DistributedDataParallel`.
241
+
242
+ Available affinity modes:
243
+ * 'socket' - the process is assigned with all available logical CPU cores
244
+ from the CPU socket connected to the GPU with a given id.
245
+ * 'single' - the process is assigned with the first available logical CPU
246
+ core from the list of all CPU cores from the CPU socket connected to the GPU
247
+ with a given id (multiple GPUs could be assigned with the same CPU core).
248
+ * 'single_unique' - the process is assigned with a single unique available
249
+ physical CPU core from the list of all CPU cores from the CPU socket
250
+ connected to the GPU with a given id.
251
+ * 'socket_unique_interleaved' - the process is assigned with an unique
252
+ subset of available physical CPU cores from the CPU socket connected to a
253
+ GPU with a given id, hyperthreading siblings are included automatically,
254
+ cores are assigned with interleaved indexing pattern
255
+ * 'socket_unique_continuous' - (the default) the process is assigned with an
256
+ unique subset of available physical CPU cores from the CPU socket connected
257
+ to a GPU with a given id, hyperthreading siblings are included
258
+ automatically, cores are assigned with continuous indexing pattern
259
+
260
+ 'socket_unique_continuous' is the recommended mode for deep learning
261
+ training workloads on NVIDIA DGX machines.
262
+
263
+ Args:
264
+ gpu_id: integer index of a GPU
265
+ nproc_per_node: number of processes per node
266
+ mode: affinity mode
267
+ balanced: assign an equal number of physical cores to each process,
268
+ affects only 'socket_unique_interleaved' and
269
+ 'socket_unique_continuous' affinity modes
270
+
271
+ Returns a set of logical CPU cores on which the process is eligible to run.
272
+
273
+ Example:
274
+
275
+ import argparse
276
+ import os
277
+
278
+ import gpu_affinity
279
+ import torch
280
+
281
+
282
+ def main():
283
+ parser = argparse.ArgumentParser()
284
+ parser.add_argument(
285
+ '--local_rank',
286
+ type=int,
287
+ default=os.getenv('LOCAL_RANK', 0),
288
+ )
289
+ args = parser.parse_args()
290
+
291
+ nproc_per_node = torch.cuda.device_count()
292
+
293
+ affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
294
+ print(f'{args.local_rank}: core affinity: {affinity}')
295
+
296
+
297
+ if __name__ == "__main__":
298
+ main()
299
+
300
+ Launch the example with:
301
+ python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py
302
+
303
+
304
+ WARNING: On DGX A100 only a half of CPU cores have direct access to GPUs.
305
+ This function restricts execution only to the CPU cores directly connected
306
+ to GPUs, so on DGX A100 it will limit the code to half of CPU cores and half
307
+ of CPU memory bandwidth (which may be fine for many DL models).
308
+ """
309
+ pynvml.nvmlInit()
310
+
311
+ if mode == "socket":
312
+ set_socket_affinity(gpu_id)
313
+ elif mode == "single":
314
+ set_single_affinity(gpu_id)
315
+ elif mode == "single_unique":
316
+ set_single_unique_affinity(gpu_id, nproc_per_node)
317
+ elif mode == "socket_unique_interleaved":
318
+ set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
319
+ elif mode == "socket_unique_continuous":
320
+ set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
321
+ else:
322
+ raise RuntimeError("Unknown affinity mode")
323
+
324
+ affinity = os.sched_getaffinity(0)
325
+ return affinity
env/SE3Transformer/build/lib/se3_transformer/runtime/inference.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ from typing import List
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn.parallel import DistributedDataParallel
29
+ from torch.utils.data import DataLoader
30
+ from tqdm import tqdm
31
+
32
+ from se3_transformer.runtime import gpu_affinity
33
+ from se3_transformer.runtime.arguments import PARSER
34
+ from se3_transformer.runtime.callbacks import BaseCallback
35
+ from se3_transformer.runtime.loggers import DLLogger
36
+ from se3_transformer.runtime.utils import to_cuda, get_local_rank
37
+
38
+
39
+ @torch.inference_mode()
40
+ def evaluate(model: nn.Module,
41
+ dataloader: DataLoader,
42
+ callbacks: List[BaseCallback],
43
+ args):
44
+ model.eval()
45
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc=f'Evaluation',
46
+ leave=False, disable=(args.silent or get_local_rank() != 0)):
47
+ *input, target = to_cuda(batch)
48
+
49
+ for callback in callbacks:
50
+ callback.on_batch_start()
51
+
52
+ with torch.cuda.amp.autocast(enabled=args.amp):
53
+ pred = model(*input)
54
+
55
+ for callback in callbacks:
56
+ callback.on_validation_step(input, target, pred)
57
+
58
+
59
+ if __name__ == '__main__':
60
+ from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
61
+ from se3_transformer.runtime.utils import init_distributed, seed_everything
62
+ from se3_transformer.model import SE3TransformerPooled, Fiber
63
+ from se3_transformer.data_loading import QM9DataModule
64
+ import torch.distributed as dist
65
+ import logging
66
+ import sys
67
+
68
+ is_distributed = init_distributed()
69
+ local_rank = get_local_rank()
70
+ args = PARSER.parse_args()
71
+
72
+ logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
73
+
74
+ logging.info('====== SE(3)-Transformer ======')
75
+ logging.info('| Inference on the test set |')
76
+ logging.info('===============================')
77
+
78
+ if not args.benchmark and args.load_ckpt_path is None:
79
+ logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
80
+ sys.exit(1)
81
+
82
+ if args.benchmark:
83
+ logging.info('Running benchmark mode with one warmup pass')
84
+
85
+ if args.seed is not None:
86
+ seed_everything(args.seed)
87
+
88
+ major_cc, minor_cc = torch.cuda.get_device_capability()
89
+
90
+ logger = DLLogger(args.log_dir, filename=args.dllogger_name)
91
+ datamodule = QM9DataModule(**vars(args))
92
+ model = SE3TransformerPooled(
93
+ fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
94
+ fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
95
+ fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
96
+ output_dim=1,
97
+ tensor_cores=(args.amp and major_cc >= 7) or major_cc >= 8, # use Tensor Cores more effectively
98
+ **vars(args)
99
+ )
100
+ callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]
101
+
102
+ model.to(device=torch.cuda.current_device())
103
+ if args.load_ckpt_path is not None:
104
+ checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
105
+ model.load_state_dict(checkpoint['state_dict'])
106
+
107
+ if is_distributed:
108
+ nproc_per_node = torch.cuda.device_count()
109
+ affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
110
+ model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
111
+
112
+ test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
113
+ evaluate(model,
114
+ test_dataloader,
115
+ callbacks,
116
+ args)
117
+
118
+ for callback in callbacks:
119
+ callback.on_validation_end()
120
+
121
+ if args.benchmark:
122
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
123
+ callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
124
+ for _ in range(6):
125
+ evaluate(model,
126
+ test_dataloader,
127
+ callbacks,
128
+ args)
129
+ callbacks[0].on_epoch_end()
130
+
131
+ callbacks[0].on_fit_end()
env/SE3Transformer/build/lib/se3_transformer/runtime/loggers.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import pathlib
25
+ from abc import ABC, abstractmethod
26
+ from enum import Enum
27
+ from typing import Dict, Any, Callable, Optional
28
+
29
+ import dllogger
30
+ import torch.distributed as dist
31
+ import wandb
32
+ from dllogger import Verbosity
33
+
34
+ from se3_transformer.runtime.utils import rank_zero_only
35
+
36
+
37
+ class Logger(ABC):
38
+ @rank_zero_only
39
+ @abstractmethod
40
+ def log_hyperparams(self, params):
41
+ pass
42
+
43
+ @rank_zero_only
44
+ @abstractmethod
45
+ def log_metrics(self, metrics, step=None):
46
+ pass
47
+
48
+ @staticmethod
49
+ def _sanitize_params(params):
50
+ def _sanitize(val):
51
+ if isinstance(val, Callable):
52
+ try:
53
+ _val = val()
54
+ if isinstance(_val, Callable):
55
+ return val.__name__
56
+ return _val
57
+ except Exception:
58
+ return getattr(val, "__name__", None)
59
+ elif isinstance(val, pathlib.Path) or isinstance(val, Enum):
60
+ return str(val)
61
+ return val
62
+
63
+ return {key: _sanitize(val) for key, val in params.items()}
64
+
65
+
66
+ class LoggerCollection(Logger):
67
+ def __init__(self, loggers):
68
+ super().__init__()
69
+ self.loggers = loggers
70
+
71
+ def __getitem__(self, index):
72
+ return [logger for logger in self.loggers][index]
73
+
74
+ @rank_zero_only
75
+ def log_metrics(self, metrics, step=None):
76
+ for logger in self.loggers:
77
+ logger.log_metrics(metrics, step)
78
+
79
+ @rank_zero_only
80
+ def log_hyperparams(self, params):
81
+ for logger in self.loggers:
82
+ logger.log_hyperparams(params)
83
+
84
+
85
+ class DLLogger(Logger):
86
+ def __init__(self, save_dir: pathlib.Path, filename: str):
87
+ super().__init__()
88
+ if not dist.is_initialized() or dist.get_rank() == 0:
89
+ save_dir.mkdir(parents=True, exist_ok=True)
90
+ dllogger.init(
91
+ backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])
92
+
93
+ @rank_zero_only
94
+ def log_hyperparams(self, params):
95
+ params = self._sanitize_params(params)
96
+ dllogger.log(step="PARAMETER", data=params)
97
+
98
+ @rank_zero_only
99
+ def log_metrics(self, metrics, step=None):
100
+ if step is None:
101
+ step = tuple()
102
+
103
+ dllogger.log(step=step, data=metrics)
104
+
105
+
106
+ class WandbLogger(Logger):
107
+ def __init__(
108
+ self,
109
+ name: str,
110
+ save_dir: pathlib.Path,
111
+ id: Optional[str] = None,
112
+ project: Optional[str] = None
113
+ ):
114
+ super().__init__()
115
+ if not dist.is_initialized() or dist.get_rank() == 0:
116
+ save_dir.mkdir(parents=True, exist_ok=True)
117
+ self.experiment = wandb.init(name=name,
118
+ project=project,
119
+ id=id,
120
+ dir=str(save_dir),
121
+ resume='allow',
122
+ anonymous='must')
123
+
124
+ @rank_zero_only
125
+ def log_hyperparams(self, params: Dict[str, Any]) -> None:
126
+ params = self._sanitize_params(params)
127
+ self.experiment.config.update(params, allow_val_change=True)
128
+
129
+ @rank_zero_only
130
+ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
131
+ if step is not None:
132
+ self.experiment.log({**metrics, 'epoch': step})
133
+ else:
134
+ self.experiment.log(metrics)
env/SE3Transformer/build/lib/se3_transformer/runtime/metrics.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ from abc import ABC, abstractmethod
25
+
26
+ import torch
27
+ import torch.distributed as dist
28
+ from torch import Tensor
29
+
30
+
31
+ class Metric(ABC):
32
+ """ Metric class with synchronization capabilities similar to TorchMetrics """
33
+
34
+ def __init__(self):
35
+ self.states = {}
36
+
37
+ def add_state(self, name: str, default: Tensor):
38
+ assert name not in self.states
39
+ self.states[name] = default.clone()
40
+ setattr(self, name, default)
41
+
42
+ def synchronize(self):
43
+ if dist.is_initialized():
44
+ for state in self.states:
45
+ dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)
46
+
47
+ def __call__(self, *args, **kwargs):
48
+ self.update(*args, **kwargs)
49
+
50
+ def reset(self):
51
+ for name, default in self.states.items():
52
+ setattr(self, name, default.clone())
53
+
54
+ def compute(self):
55
+ self.synchronize()
56
+ value = self._compute().item()
57
+ self.reset()
58
+ return value
59
+
60
+ @abstractmethod
61
+ def _compute(self):
62
+ pass
63
+
64
+ @abstractmethod
65
+ def update(self, preds: Tensor, targets: Tensor):
66
+ pass
67
+
68
+
69
+ class MeanAbsoluteError(Metric):
70
+ def __init__(self):
71
+ super().__init__()
72
+ self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
73
+ self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))
74
+
75
+ def update(self, preds: Tensor, targets: Tensor):
76
+ preds = preds.detach()
77
+ n = preds.shape[0]
78
+ error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
79
+ self.total += n
80
+ self.error += error
81
+
82
+ def _compute(self):
83
+ return self.error / self.total
env/SE3Transformer/build/lib/se3_transformer/runtime/training.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import logging
25
+ import pathlib
26
+ from typing import List
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.distributed as dist
31
+ import torch.nn as nn
32
+ from apex.optimizers import FusedAdam, FusedLAMB
33
+ from torch.nn.modules.loss import _Loss
34
+ from torch.nn.parallel import DistributedDataParallel
35
+ from torch.optim import Optimizer
36
+ from torch.utils.data import DataLoader, DistributedSampler
37
+ from tqdm import tqdm
38
+
39
+ from se3_transformer.data_loading import QM9DataModule
40
+ from se3_transformer.model import SE3TransformerPooled
41
+ from se3_transformer.model.fiber import Fiber
42
+ from se3_transformer.runtime import gpu_affinity
43
+ from se3_transformer.runtime.arguments import PARSER
44
+ from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
45
+ PerformanceCallback
46
+ from se3_transformer.runtime.inference import evaluate
47
+ from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
48
+ from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
49
+ using_tensor_cores, increase_l2_fetch_granularity
50
+
51
+
52
+ def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
53
+ """ Saves model, optimizer and epoch states to path (only once per node) """
54
+ if get_local_rank() == 0:
55
+ state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
56
+ checkpoint = {
57
+ 'state_dict': state_dict,
58
+ 'optimizer_state_dict': optimizer.state_dict(),
59
+ 'epoch': epoch
60
+ }
61
+ for callback in callbacks:
62
+ callback.on_checkpoint_save(checkpoint)
63
+
64
+ torch.save(checkpoint, str(path))
65
+ logging.info(f'Saved checkpoint to {str(path)}')
66
+
67
+
68
+ def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
69
+ """ Loads model, optimizer and epoch states from path """
70
+ checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
71
+ if isinstance(model, DistributedDataParallel):
72
+ model.module.load_state_dict(checkpoint['state_dict'])
73
+ else:
74
+ model.load_state_dict(checkpoint['state_dict'])
75
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
76
+
77
+ for callback in callbacks:
78
+ callback.on_checkpoint_load(checkpoint)
79
+
80
+ logging.info(f'Loaded checkpoint from {str(path)}')
81
+ return checkpoint['epoch']
82
+
83
+
84
+ def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
85
+ losses = []
86
+ for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
87
+ desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
88
+ *inputs, target = to_cuda(batch)
89
+
90
+ for callback in callbacks:
91
+ callback.on_batch_start()
92
+
93
+ with torch.cuda.amp.autocast(enabled=args.amp):
94
+ pred = model(*inputs)
95
+ loss = loss_fn(pred, target) / args.accumulate_grad_batches
96
+
97
+ grad_scaler.scale(loss).backward()
98
+
99
+ # gradient accumulation
100
+ if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
101
+ if args.gradient_clip:
102
+ grad_scaler.unscale_(optimizer)
103
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
104
+
105
+ grad_scaler.step(optimizer)
106
+ grad_scaler.update()
107
+ optimizer.zero_grad()
108
+
109
+ losses.append(loss.item())
110
+
111
+ return np.mean(losses)
112
+
113
+
114
+ def train(model: nn.Module,
115
+ loss_fn: _Loss,
116
+ train_dataloader: DataLoader,
117
+ val_dataloader: DataLoader,
118
+ callbacks: List[BaseCallback],
119
+ logger: Logger,
120
+ args):
121
+ device = torch.cuda.current_device()
122
+ model.to(device=device)
123
+ local_rank = get_local_rank()
124
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
125
+
126
+ if dist.is_initialized():
127
+ model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
128
+
129
+ model.train()
130
+ grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
131
+ if args.optimizer == 'adam':
132
+ optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
133
+ weight_decay=args.weight_decay)
134
+ elif args.optimizer == 'lamb':
135
+ optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
136
+ weight_decay=args.weight_decay)
137
+ else:
138
+ optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
139
+ weight_decay=args.weight_decay)
140
+
141
+ epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
142
+
143
+ for callback in callbacks:
144
+ callback.on_fit_start(optimizer, args)
145
+
146
+ for epoch_idx in range(epoch_start, args.epochs):
147
+ if isinstance(train_dataloader.sampler, DistributedSampler):
148
+ train_dataloader.sampler.set_epoch(epoch_idx)
149
+
150
+ loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
151
+ if dist.is_initialized():
152
+ loss = torch.tensor(loss, dtype=torch.float, device=device)
153
+ torch.distributed.all_reduce(loss)
154
+ loss = (loss / world_size).item()
155
+
156
+ logging.info(f'Train loss: {loss}')
157
+ logger.log_metrics({'train loss': loss}, epoch_idx)
158
+
159
+ for callback in callbacks:
160
+ callback.on_epoch_end()
161
+
162
+ if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
163
+ and (epoch_idx + 1) % args.ckpt_interval == 0:
164
+ save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
165
+
166
+ if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
167
+ evaluate(model, val_dataloader, callbacks, args)
168
+ model.train()
169
+
170
+ for callback in callbacks:
171
+ callback.on_validation_end(epoch_idx)
172
+
173
+ if args.save_ckpt_path is not None and not args.benchmark:
174
+ save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
175
+
176
+ for callback in callbacks:
177
+ callback.on_fit_end()
178
+
179
+
180
+ def print_parameters_count(model):
181
+ num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
182
+ logging.info(f'Number of trainable parameters: {num_params_trainable}')
183
+
184
+
185
+ if __name__ == '__main__':
186
+ is_distributed = init_distributed()
187
+ local_rank = get_local_rank()
188
+ args = PARSER.parse_args()
189
+
190
+ logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
191
+
192
+ logging.info('====== SE(3)-Transformer ======')
193
+ logging.info('| Training procedure |')
194
+ logging.info('===============================')
195
+
196
+ if args.seed is not None:
197
+ logging.info(f'Using seed {args.seed}')
198
+ seed_everything(args.seed)
199
+
200
+ logger = LoggerCollection([
201
+ DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
202
+ WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
203
+ ])
204
+
205
+ datamodule = QM9DataModule(**vars(args))
206
+ model = SE3TransformerPooled(
207
+ fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
208
+ fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
209
+ fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
210
+ output_dim=1,
211
+ tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
212
+ **vars(args)
213
+ )
214
+ loss_fn = nn.L1Loss()
215
+
216
+ if args.benchmark:
217
+ logging.info('Running benchmark mode')
218
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
219
+ callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
220
+ else:
221
+ callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
222
+ QM9LRSchedulerCallback(logger, epochs=args.epochs)]
223
+
224
+ if is_distributed:
225
+ gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
226
+
227
+ print_parameters_count(model)
228
+ logger.log_hyperparams(vars(args))
229
+ increase_l2_fetch_granularity()
230
+ train(model,
231
+ loss_fn,
232
+ datamodule.train_dataloader(),
233
+ datamodule.val_dataloader(),
234
+ callbacks,
235
+ logger,
236
+ args)
237
+
238
+ logging.info('Training finished successfully')
env/SE3Transformer/build/lib/se3_transformer/runtime/utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import argparse
25
+ import ctypes
26
+ import logging
27
+ import os
28
+ import random
29
+ from functools import wraps
30
+ from typing import Union, List, Dict
31
+
32
+ import numpy as np
33
+ import torch
34
+ import torch.distributed as dist
35
+ from torch import Tensor
36
+
37
+
38
+ def aggregate_residual(feats1, feats2, method: str):
39
+ """ Add or concatenate two fiber features together. If degrees don't match, will use the ones of feats2. """
40
+ if method in ['add', 'sum']:
41
+ return {k: (v + feats1[k]) if k in feats1 else v for k, v in feats2.items()}
42
+ elif method in ['cat', 'concat']:
43
+ return {k: torch.cat([v, feats1[k]], dim=1) if k in feats1 else v for k, v in feats2.items()}
44
+ else:
45
+ raise ValueError('Method must be add/sum or cat/concat')
46
+
47
+
48
+ def degree_to_dim(degree: int) -> int:
49
+ return 2 * degree + 1
50
+
51
+
52
+ def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
53
+ return dict(zip(map(str, degrees), features.split([degree_to_dim(deg) for deg in degrees], dim=-1)))
54
+
55
+
56
+ def str2bool(v: Union[bool, str]) -> bool:
57
+ if isinstance(v, bool):
58
+ return v
59
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
60
+ return True
61
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
62
+ return False
63
+ else:
64
+ raise argparse.ArgumentTypeError('Boolean value expected.')
65
+
66
+
67
+ def to_cuda(x):
68
+ """ Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA """
69
+ if isinstance(x, Tensor):
70
+ return x.cuda(non_blocking=True)
71
+ elif isinstance(x, tuple):
72
+ return (to_cuda(v) for v in x)
73
+ elif isinstance(x, list):
74
+ return [to_cuda(v) for v in x]
75
+ elif isinstance(x, dict):
76
+ return {k: to_cuda(v) for k, v in x.items()}
77
+ else:
78
+ # DGLGraph or other objects
79
+ return x.to(device=torch.cuda.current_device())
80
+
81
+
82
+ def get_local_rank() -> int:
83
+ return int(os.environ.get('LOCAL_RANK', 0))
84
+
85
+
86
+ def init_distributed() -> bool:
87
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
88
+ distributed = world_size > 1
89
+ if distributed:
90
+ backend = 'nccl' if torch.cuda.is_available() else 'gloo'
91
+ dist.init_process_group(backend=backend, init_method='env://')
92
+ if backend == 'nccl':
93
+ torch.cuda.set_device(get_local_rank())
94
+ else:
95
+ logging.warning('Running on CPU only!')
96
+ assert torch.distributed.is_initialized()
97
+ return distributed
98
+
99
+
100
+ def increase_l2_fetch_granularity():
101
+ # maximum fetch granularity of L2: 128 bytes
102
+ _libcudart = ctypes.CDLL('libcudart.so')
103
+ # set device limit on the current device
104
+ # cudaLimitMaxL2FetchGranularity = 0x05
105
+ pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
106
+ _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
107
+ _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
108
+ assert pValue.contents.value == 128
109
+
110
+
111
+ def seed_everything(seed):
112
+ seed = int(seed)
113
+ random.seed(seed)
114
+ np.random.seed(seed)
115
+ torch.manual_seed(seed)
116
+ torch.cuda.manual_seed_all(seed)
117
+
118
+
119
+ def rank_zero_only(fn):
120
+ @wraps(fn)
121
+ def wrapped_fn(*args, **kwargs):
122
+ if not dist.is_initialized() or dist.get_rank() == 0:
123
+ return fn(*args, **kwargs)
124
+
125
+ return wrapped_fn
126
+
127
+
128
+ def using_tensor_cores(amp: bool) -> bool:
129
+ major_cc, minor_cc = torch.cuda.get_device_capability()
130
+ return (amp and major_cc >= 7) or major_cc >= 8
env/SE3Transformer/build/lib/tests/__init__.py ADDED
File without changes
env/SE3Transformer/build/lib/tests/test_equivariance.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import torch
25
+
26
+ from se3_transformer.model import SE3Transformer
27
+ from se3_transformer.model.fiber import Fiber
28
+ from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
29
+
30
+ # Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
31
+ TOL = 1e-3
32
+ CHANNELS, NODES = 32, 512
33
+
34
+
35
+ def _get_outputs(model, R):
36
+ feats0 = torch.randn(NODES, CHANNELS, 1)
37
+ feats1 = torch.randn(NODES, CHANNELS, 3)
38
+
39
+ coords = torch.randn(NODES, 3)
40
+ graph = get_random_graph(NODES)
41
+ if torch.cuda.is_available():
42
+ feats0 = feats0.cuda()
43
+ feats1 = feats1.cuda()
44
+ R = R.cuda()
45
+ coords = coords.cuda()
46
+ graph = graph.to('cuda')
47
+ model.cuda()
48
+
49
+ graph1 = assign_relative_pos(graph, coords)
50
+ out1 = model(graph1, {'0': feats0, '1': feats1}, {})
51
+ graph2 = assign_relative_pos(graph, coords @ R)
52
+ out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
53
+
54
+ return out1, out2
55
+
56
+
57
+ def _get_model(**kwargs):
58
+ return SE3Transformer(
59
+ num_layers=4,
60
+ fiber_in=Fiber.create(2, CHANNELS),
61
+ fiber_hidden=Fiber.create(3, CHANNELS),
62
+ fiber_out=Fiber.create(2, CHANNELS),
63
+ fiber_edge=Fiber({}),
64
+ num_heads=8,
65
+ channels_div=2,
66
+ **kwargs
67
+ )
68
+
69
+
70
+ def test_equivariance():
71
+ model = _get_model()
72
+ R = rot(*torch.rand(3))
73
+ if torch.cuda.is_available():
74
+ R = R.cuda()
75
+ out1, out2 = _get_outputs(model, R)
76
+
77
+ assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
78
+ f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
79
+ assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
80
+ f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
81
+
82
+
83
+ def test_equivariance_pooled():
84
+ model = _get_model(pooling='avg', return_type=1)
85
+ R = rot(*torch.rand(3))
86
+ if torch.cuda.is_available():
87
+ R = R.cuda()
88
+ out1, out2 = _get_outputs(model, R)
89
+
90
+ assert torch.allclose(out2, (out1 @ R), atol=TOL), \
91
+ f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
92
+
93
+
94
+ def test_invariance_pooled():
95
+ model = _get_model(pooling='avg', return_type=0)
96
+ R = rot(*torch.rand(3))
97
+ if torch.cuda.is_available():
98
+ R = R.cuda()
99
+ out1, out2 = _get_outputs(model, R)
100
+
101
+ assert torch.allclose(out2, out1, atol=TOL), \
102
+ f'type-0 features should be invariant {get_max_diff(out1, out2)}'
env/SE3Transformer/build/lib/tests/utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import dgl
25
+ import torch
26
+
27
+
28
+ def get_random_graph(N, num_edges_factor=18):
29
+ graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
30
+ return graph
31
+
32
+
33
+ def assign_relative_pos(graph, coords):
34
+ src, dst = graph.edges()
35
+ graph.edata['rel_pos'] = coords[src] - coords[dst]
36
+ return graph
37
+
38
+
39
+ def get_max_diff(a, b):
40
+ return (a - b).abs().max().item()
41
+
42
+
43
+ def rot_z(gamma):
44
+ return torch.tensor([
45
+ [torch.cos(gamma), -torch.sin(gamma), 0],
46
+ [torch.sin(gamma), torch.cos(gamma), 0],
47
+ [0, 0, 1]
48
+ ], dtype=gamma.dtype)
49
+
50
+
51
+ def rot_y(beta):
52
+ return torch.tensor([
53
+ [torch.cos(beta), 0, torch.sin(beta)],
54
+ [0, 1, 0],
55
+ [-torch.sin(beta), 0, torch.cos(beta)]
56
+ ], dtype=beta.dtype)
57
+
58
+
59
+ def rot(alpha, beta, gamma):
60
+ return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)