import datetime
import glob
import json
import os
import subprocess
import sys
import time
import unittest
from shutil import copyfile

import torch

from rfdiffusion.inference import utils as iu
from rfdiffusion.util import calc_rmsd

# Absolute path of the directory containing this test file; every path used
# by the tests below (examples, reference outputs, run folders) is built from it.
script_dir = os.path.dirname(os.path.abspath(__file__))

class TestSubmissionCommands(unittest.TestCase):
    """
    Test harness for checking that commands in the examples folder,
    when run in deterministic mode, produce the same output as the
    reference outputs.

    Requirements:
        - example command must be written on a single line
        - outputs must be written to example_outputs folder
        - needs to be run on the same hardware as the reference outputs (A100 GPU)

    For speed, we only run the first 2 steps of diffusion, and set inference.num_designs=1
    This means that outputs DO NOT look like proteins, but we can still check that the
    outputs are the same as the reference outputs.
    """

    def setUp(self):
        """
        Rewrite and run every example script before the comparison test.

        Creates a timestamped ``tests_<date>`` folder next to this file,
        symlinks each sub-directory of the examples folder next to this file,
        writes a shortened deterministic copy of every ``examples/*.sh``
        script into the new folder and executes each copy with bash.
        The exit state and captured log of each script are stored in
        ``self.results`` under the script's base name.
        """
        examples_dir = f"{script_dir}/../examples"
        submissions = glob.glob(f"{examples_dir}/*.sh")

        now = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        self.out_f = f"{script_dir}/tests_{now}"
        os.mkdir(self.out_f)

        exclude_dirs = ["outputs", "example_outputs"]
        for filename in os.listdir(examples_dir):
            source = os.path.join(examples_dir, filename)
            link = os.path.join(script_dir, filename)
            if filename in exclude_dirs or not os.path.isdir(source):
                continue
            # lexists() is true for symlinks (even dangling ones) as well as
            # for real files/directories, so an already occupied name is
            # skipped instead of making os.symlink raise FileExistsError.
            if not os.path.lexists(link):
                os.symlink(source, link)

        for submission in submissions:
            self._write_command(submission, self.out_f)

        print(f"Running commands in {self.out_f}, two steps of diffusion, deterministic=True")

        self.results = {}
        for bash_file in sorted(glob.glob(f"{self.out_f}/*.sh"), reverse=False):
            test_name = os.path.basename(bash_file)[:-len('.sh')]
            res, output = execute(
                f"Running {test_name}",
                f'bash {bash_file}',
                return_='tuple',
                add_message_and_command_line_to_output=True,
            )
            # A non-zero exit code marks the example as failed.
            self.results[test_name] = dict(
                state='failed' if res else 'passed',
                log=output,
            )

    def test_commands(self):
        """
        Compare every generated PDB file with its reference output.

        For each ``*pdb`` file in ``<out_f>/example_outputs``: if no file of
        the same name exists in ``reference_outputs`` the generated file is
        copied there as the new reference; otherwise the value returned by
        ``calc_rmsd`` between the two structures must be 0 to two decimal
        places. Per-file outcomes are added to ``self.results``, the whole
        dictionary is dumped to ``.results.json`` in the current working
        directory, and the test fails if any comparison failed.
        """
        reference = f'{script_dir}/reference_outputs'
        os.makedirs(reference, exist_ok=True)
        test_files = glob.glob(f"{self.out_f}/example_outputs/*pdb")
        print(f'{self.out_f=} {test_files=}')

        result = self.defaultTestResult()
        for test_file in test_files:
            with self.subTest(test_file=test_file):
                pdb_name = os.path.basename(test_file)
                reference_file = f"{reference}/{pdb_name}"

                test_pdb = iu.parse_pdb(test_file)
                if not os.path.exists(reference_file):
                    copyfile(test_file, reference_file)
                    print(f"Created reference file {reference_file}")
                    continue

                ref_pdb = iu.parse_pdb(reference_file)
                # NOTE(review): [:, :3] presumably selects the backbone atoms
                # of each residue - confirm against iu.parse_pdb.
                rmsd = calc_rmsd(
                    test_pdb['xyz'][:, :3].reshape(-1, 3),
                    ref_pdb['xyz'][:, :3].reshape(-1, 3),
                )[0]
                try:
                    self.assertAlmostEqual(rmsd, 0, 2)
                    result.addSuccess(self)
                    print(f"Subtest {test_file} passed")

                    state = 'passed'
                    log = f'Subtest {test_file} passed'
                except AssertionError as e:
                    # TestResult.addFailure expects the (type, value,
                    # traceback) tuple from sys.exc_info(); passing the bare
                    # exception raises TypeError and loses the failure record.
                    result.addFailure(self, sys.exc_info())
                    print(f"Subtest {test_file} failed")

                    state = 'failed'
                    log = f'Subtest {test_file} failed:\n{e!r}'

                self.results['pdb-diff.' + pdb_name] = dict(state=state, log=log)

        with open('.results.json', 'w') as f:
            json.dump(self.results, f, sort_keys=True, indent=2)

        self.assertTrue(result.wasSuccessful(), "One or more subtests failed")

    def _write_command(self, bash_file, test_f) -> None:
        """
        Takes a bash file from the examples folder, and writes
        a version of it to the test_f folder.

        Lines that start with ``python`` or ``../`` are treated as inference
        commands and rewritten by :meth:`_rewrite_command`; every other line
        is copied unchanged. The copied lines are written first, followed by
        the rewritten command(s), one per line. A script without any command
        line is copied as-is.

        It appends to the python command the following arguments:
            inference.deterministic=True
            if partial_T is in the command, it grabs partial T and sets:
                inference.final_step=partial_T-2
            else:
                inference.final_step=48
        """
        out_lines = []
        commands = []
        with open(bash_file, "r") as f:
            for line in f:
                if line.startswith(("python", "../")):
                    commands.append(self._rewrite_command(line.strip()))
                else:
                    out_lines.append(line)

        # Make sure a command is never glued onto an unterminated last line.
        if out_lines and not out_lines[-1].endswith("\n"):
            out_lines[-1] += "\n"

        with open(f"{test_f}/{os.path.basename(bash_file)}", "w") as f:
            f.writelines(out_lines)
            f.writelines(f"{command}\n" for command in commands)

    def _rewrite_command(self, command):
        """Return *command* shortened to a deterministic, single-design run writing into ``self.out_f``."""
        if not command.startswith("python"):
            command = f'python {command}'

        if "partial_T=" in command:
            final_step = int(command.partition("partial_T=")[2].split(" ")[0]) - 2
        else:
            final_step = 48

        command = f"{command} inference.deterministic=True inference.final_step={final_step}"

        # Force a single design: replace an existing value, or append the option.
        head, found, tail = command.partition("inference.num_designs=")
        if found:
            rest = tail.partition(" ")[2]
            command = f"{head}inference.num_designs=1 {rest}"
        else:
            command = f"{command} inference.num_designs=1"

        # Redirect the output prefix into this run's folder. Only the first
        # occurrence is touched and the rest of the command is kept intact.
        return command.replace("example_outputs", f"{self.out_f}/example_outputs", 1)


def execute_through_pty(command_line):
    """
    Run *command_line* through the shell with its stdin, stdout and stderr
    attached to a pseudo-terminal and capture everything it writes.

    Args:
        command_line: shell command string (executed with ``shell=True``).

    Returns:
        ``(exit_code, output)`` where ``output`` is the captured bytes decoded
        as UTF-8 with undecodable bytes backslash-escaped.
    """
    import pty
    import select

    master, slave = pty.openpty()
    slave_open = True
    chunks = []
    try:
        p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
                             stderr=subprocess.STDOUT, close_fds=True)

        if sys.platform == "darwin":
            # The slave end stays open in this process here, so the end of
            # the output is detected by polling: stop once the child has
            # exited and nothing is left to read.
            while True:
                try:
                    if select.select([master], [], [], 0.2)[0]:
                        data = os.read(master, 1 << 22)
                        if data:
                            chunks.append(data)
                    elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0]):
                        break
                except OSError:
                    break
        else:
            # Close our copy of the slave so that reading the master fails
            # (OSError) or hits EOF once the child side is fully closed.
            os.close(slave)
            slave_open = False

            while True:
                try:
                    data = os.read(master, 1 << 22)
                except OSError:
                    break
                if not data:
                    # Empty read means EOF on some platforms; without this
                    # check the loop would spin forever.
                    break
                chunks.append(data)
    finally:
        # Always release the pty descriptors, even if Popen or a read raised.
        os.close(master)
        if slave_open:
            os.close(slave)

    output = b''.join(chunks).decode(encoding='utf-8', errors='backslashreplace')

    p.wait()
    exit_code = p.returncode

    return exit_code, output


class BenchmarkError(Exception):
    """Raised by execute() when a command fails and termination on failure was requested."""


def execute(message, command_line, return_='status', until_successes=False,
            terminate_on_failure=True, silent=False, silence_output=False,
            silence_output_on_errors=False,
            add_message_and_command_line_to_output=False):
    """
    Run *command_line* via execute_through_pty and report the outcome.

    Args:
        message: human-readable description, printed before running.
        command_line: shell command to execute.
        return_: selects the return value. ``'tuple'`` (or the ``tuple``
            type) returns ``(exit_code, output)`` and never raises;
            ``'output'`` returns the captured output; ``True`` returns
            ``True`` on failure instead of raising; anything else
            (default ``'status'``) returns the exit code.
        until_successes: if true, retry every 60 seconds until the command
            exits with code 0.
        terminate_on_failure: if true, a failing command raises
            BenchmarkError (unless ``return_`` is ``'tuple'`` or ``True``).
        silent: suppress the message/command echo and the output echo.
        silence_output: suppress the output echo only.
        silence_output_on_errors: also suppress the output echo when the
            command fails.
        add_message_and_command_line_to_output: prefix the returned output
            with the message and the command line.

    Returns:
        See ``return_``.

    Raises:
        BenchmarkError: the command failed, ``terminate_on_failure`` is true
            and ``return_`` is neither ``'tuple'`` nor ``True``.
    """
    if not silent:
        print(message)
        print(command_line)
        sys.stdout.flush()

    while True:
        exit_code, output = execute_through_pty(command_line)

        if (exit_code and not silence_output_on_errors) or not (silent or silence_output):
            print(output)
            sys.stdout.flush()

        if not (exit_code and until_successes):
            break

        print("Error while executing {}: {}\n".format(message, output))
        print("Sleeping 60s... then I will retry...")
        sys.stdout.flush()
        time.sleep(60)

    if add_message_and_command_line_to_output:
        output = message + '\nCommand line: ' + command_line + '\n' + output

    if return_ == 'tuple' or return_ == tuple:
        return exit_code, output

    if exit_code and terminate_on_failure:
        print("\nEncounter error while executing: " + command_line)
        # Equality (not identity) is kept on purpose so existing callers that
        # pass a truthy 1 keep working.
        if return_ == True:  # noqa: E712
            return True
        print('\nEncounter error while executing: ' + command_line + '\n' + output)
        raise BenchmarkError('\nEncounter error while executing: ' + command_line + '\n' + output)

    if return_ == 'output':
        return output
    return exit_code


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()