Spaces:
Runtime error
Runtime error
File size: 2,100 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import argparse
import os
import shelve
import time
from typing import Mapping, Optional

import pandas as pd
import ray

from tools.eval.base import EvalTask, TaskScanner
from tools.eval.similarity import eval_similarity
from tools.eval.energy import eval_interface_energy
@ray.remote(num_cpus=1)
def evaluate(task, args):
    """Run the evaluator pipeline on one task inside a Ray worker (1 CPU).

    ``eval_similarity`` is always applied. ``eval_interface_energy`` is
    applied afterwards unless ``args.no_energy`` is set. Each evaluator
    receives the task returned by the previous one, and the final task
    object is returned to the driver.
    """
    pipeline = (
        (eval_similarity,)
        if args.no_energy
        else (eval_similarity, eval_interface_energy)
    )
    result = task
    for step in pipeline:
        result = step(result)
    return result
def dump_db(db: Mapping[str, 'EvalTask'], path, *, abopt: Optional[bool] = None):
    """Write a CSV report of the evaluated tasks in *db* and return it as a DataFrame.

    Args:
        db: Mapping of evaluated tasks (``main`` passes its shelve database).
            Each task must provide ``to_report_dict()``. When wild-type
            filtering is active, it must also provide ``scores['seqid']``.
        path: Destination CSV path (``str`` or ``os.PathLike``).
        abopt: If true, skip tasks whose ``scores['seqid']`` is >= 100.0,
            i.e. sequences identical to the wild type (antibody-optimization
            mode). ``None`` (default) keeps the legacy behaviour: filtering
            is enabled when the substring ``'abopt'`` occurs anywhere in
            *path*.

    Returns:
        The DataFrame that was written, one row per reported task.
    """
    path_str = os.fspath(path)
    if abopt is None:
        abopt = 'abopt' in path_str

    rows = []
    for task in db.values():
        if abopt and task.scores['seqid'] >= 100.0:
            # In abopt (Antibody Optimization) mode, ignore sequences identical to the wild-type
            continue
        rows.append(task.to_report_dict())
    table = pd.DataFrame(rows)

    # main() rewrites this file once per second. Write a sibling temp file and
    # atomically swap it in, so concurrent readers never see a partial CSV.
    tmp_path = path_str + '.tmp'
    try:
        table.to_csv(tmp_path, index=False, float_format='%.6f')
        os.replace(tmp_path, path_str)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return table
def main():
    """Continuously evaluate tasks found under ``--root`` and refresh the summary CSV.

    Command-line options:
      --root       directory handed to ``TaskScanner``; it also holds the
                   ``evaluation_db`` shelve and ``summary.csv`` (default
                   ``./results``).
      --pfx        postfix handed to ``TaskScanner`` (default ``rosetta``).
      --no_energy  skip the interface-energy evaluator in ``evaluate``.

    Loops forever. Each pass:
      1. scans for tasks and submits each one to Ray;
      2. for each task as it finishes, calls ``save_to_db(db)`` and then
         ``db.sync()``;
      3. rewrites ``summary.csv`` via ``dump_db``;
      4. sleeps for one second.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--root', type=str, default='./results')
    cli.add_argument('--pfx', type=str, default='rosetta')
    cli.add_argument('--no_energy', action='store_true', default=False)
    args = cli.parse_args()
    ray.init()

    summary_csv = os.path.join(args.root, 'summary.csv')
    with shelve.open(os.path.join(args.root, 'evaluation_db')) as db:
        scanner = TaskScanner(root=args.root, postfix=args.pfx, db=db)
        while True:
            pending = [evaluate.remote(job, args) for job in scanner.scan()]
            if pending:
                print(f'Submitted {len(pending)} tasks.')
            while pending:
                # Block until one more task completes; `pending` keeps the rest.
                ready, pending = ray.wait(pending, num_returns=1)
                for ref in ready:
                    finished = ray.get(ref)
                    finished.save_to_db(db)
                    print(f'Remaining {len(pending)}. Finished {finished.in_path}')
                db.sync()
            dump_db(db, summary_csv)
            time.sleep(1.0)


if __name__ == '__main__':
    # Script entry point.
    main()
|