Spaces:
Runtime error
Runtime error
# pyright: reportMissingImports=false | |
import os | |
import time | |
import pyrosetta | |
from pyrosetta.rosetta.protocols.relax import FastRelax | |
from pyrosetta.rosetta.core.pack.task import TaskFactory | |
from pyrosetta.rosetta.core.pack.task import operation | |
from pyrosetta.rosetta.core.select import residue_selector as selections | |
from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action | |
# Initialise PyRosetta at import time, so every importer of this module gets
# the same option set. The flags are passed as a single space-separated
# command-line string; each flag's value (if any) sits on the same line.
pyrosetta.init(
    '-mute all'
    ' -use_input_sc'
    ' -ignore_unrecognized_res'
    ' -ignore_zero_occupancy false'
    ' -load_PDB_components false'
    ' -relax:default_repeats 2'
    ' -no_fconfig'
)
from diffab.tools.relax.base import RelaxTask | |
def current_milli_time():
    """Return the current wall-clock time, in whole milliseconds since the epoch."""
    millis = time.time() * 1000
    return round(millis)
def parse_residue_position(p):
    """Split a residue label such as ``'H100'`` or ``'H100A'`` into its parts.

    The label is ``<chain><resseq>[<icode>]``. The chain is everything before
    the first numeric character, followed by the residue number and an
    optional single trailing non-numeric insertion code.

    Args:
        p: Residue label string.

    Returns:
        ``(chain, resseq)``, or ``(chain, resseq, icode)`` when an insertion
        code is present.

    Raises:
        ValueError: if ``p`` is empty or contains no residue number.
    """
    if not p:
        raise ValueError('Empty residue position.')
    icode = None
    body = p
    if not p[-1].isnumeric():  # Has ICODE
        icode = p[-1]
        # Drop the insertion code so only the digits reach int() below.
        body = p[:-1]
    for i, c in enumerate(body):
        if c.isnumeric():
            break
    else:
        # Loop finished without finding a digit (or body was empty).
        raise ValueError(f'No residue number in residue position {p!r}.')
    chain = body[:i]
    resseq = int(body[i:])
    if icode is not None:
        return chain, resseq, icode
    return chain, resseq
def _scorefxn_corrections(scorefxn_name: str) -> dict:
    """Return the ``corrections:<name>`` booleans to apply for *scorefxn_name*.

    At most one correction is switched on, chosen by the first matching
    substring of the score-function name. Every other correction is
    explicitly set to False.
    """
    corrections = {
        'beta_july15': False,
        'beta_nov16': False,
        'gen_potential': False,
        'restore_talaris_behavior': False,
    }
    if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name:
        # beta_july15 is ref2015
        corrections['beta_july15'] = True
    elif 'beta_nov16' in scorefxn_name:
        corrections['beta_nov16'] = True
    elif 'genpot' in scorefxn_name:
        # NOTE(review): the upstream gist also called
        # set_boolean_option('corrections:beta_july15', True) here, but the
        # option loop in get_scorefxn immediately reset it to False, so it
        # never took effect. That dead write is dropped (behaviour unchanged).
        # Confirm whether GenPot needs beta_july15 before enabling it here.
        corrections['gen_potential'] = True
    elif 'talaris' in scorefxn_name:  # talaris2013 and talaris2014
        corrections['restore_talaris_behavior'] = True
    return corrections


def get_scorefxn(scorefxn_name: str):
    """
    Gets the scorefxn with appropriate corrections.
    Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550

    Every ``corrections:<name>`` option returned by ``_scorefxn_corrections``
    is written with ``basic.options.set_boolean_option`` before the score
    function is created. All four options are therefore (re)set on each call.

    Args:
        scorefxn_name: Rosetta score-function name, e.g. ``'ref2015'``.

    Returns:
        The score function built by ``pyrosetta.create_score_function``.
    """
    import pyrosetta  # local import keeps this helper self-contained
    for corr, value in _scorefxn_corrections(scorefxn_name).items():
        pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value)
    return pyrosetta.create_score_function(scorefxn_name)
class RelaxRegion(object):
    """FastRelax a contiguous stretch of residues of a PDB structure.

    Args:
        scorefxn: Rosetta score-function name, resolved through ``get_scorefxn``.
        max_iter: Passed to ``FastRelax.max_iter``.
        subset: Which residues may repack and have their side-chain torsions
            enabled in the MoveMap:
            ``'target'`` - only the flexible region;
            ``'nbrs'``   - the flexible region plus its
                           ``NeighborhoodResidueSelector`` neighbourhood;
            ``'all'``    - every residue of the pose.
        move_bb: If True, backbone torsions of the flexible region are enabled
            in the MoveMap; otherwise backbone is not explicitly enabled.

    Raises:
        ValueError: if ``subset`` is not one of the values above.
    """

    _SUBSETS = ('all', 'target', 'nbrs')

    def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True):
        super().__init__()
        # Validate before get_scorefxn, which sets PyRosetta options.
        if subset not in self._SUBSETS:
            raise ValueError(f'subset must be one of {self._SUBSETS}, got {subset!r}')
        self.scorefxn = get_scorefxn(scorefxn)
        self.fast_relax = FastRelax()
        self.fast_relax.set_scorefxn(self.scorefxn)
        self.fast_relax.max_iter(max_iter)
        self.subset = subset
        self.move_bb = move_bb

    @staticmethod
    def _strip_blank_icode(res):
        """Drop a trailing blank (' ') insertion code from a PDB residue id."""
        if res[-1] == ' ':
            return res[:-1]
        return res

    @staticmethod
    def _pdb_to_pose_index(pose, res):
        """Map a PDB residue id to a pose index; raise if it is not in the pose."""
        idx = pose.pdb_info().pdb2pose(*res)
        if idx == 0:  # pdb2pose signals "not found" with 0
            raise ValueError(f'Residue {tuple(res)!r} not found in pose.')
        return idx

    def _build_selectors(self, first_idx, last_idx):
        """Return ``(region_selector, subset_selector)`` for ``self.subset``."""
        gen_selector = selections.ResidueIndexSelector()
        gen_selector.set_index_range(first_idx, last_idx)
        if self.subset == 'target':
            return gen_selector, gen_selector
        if self.subset == 'nbrs':
            nbr_selector = selections.NeighborhoodResidueSelector()
            nbr_selector.set_focus_selector(gen_selector)
            nbr_selector.set_include_focus_in_subset(True)
            return gen_selector, nbr_selector
        # 'all': side chains of every residue are free.
        return gen_selector, selections.TrueResidueSelector()

    def _build_task_factory(self, subset_selector):
        """Repack-only task; positions outside the subset are frozen unless subset == 'all'."""
        tf = TaskFactory()
        tf.push_back(operation.InitializeFromCommandline())
        tf.push_back(operation.RestrictToRepacking())  # No design at any position.
        if self.subset != 'all':
            # flip_subset=True applies PreventRepacking to everything NOT selected.
            tf.push_back(operation.OperateOnResidueSubset(
                operation.PreventRepackingRLT(),
                subset_selector,
                flip_subset=True,
            ))
        return tf

    def __call__(self, pdb_path, flexible_residue_first, flexible_residue_last):
        """Relax the region between two PDB residues (inclusive).

        Args:
            pdb_path: PDB file loaded with ``pyrosetta.pose_from_pdb``.
            flexible_residue_first: First residue of the flexible region, as a
                sequence unpacked into ``pdb_info().pdb2pose`` (e.g.
                ``(chain, resseq)`` or ``(chain, resseq, icode)``). A trailing
                blank icode is dropped.
            flexible_residue_last: Last residue of the flexible region, same
                format as ``flexible_residue_first``.

        Returns:
            ``(relaxed_pose, score_before, score_after)``, scored with this
            instance's score function.

        Raises:
            ValueError: if either residue is not present in the pose.
        """
        pose = pyrosetta.pose_from_pdb(pdb_path)
        original_pose = pose.clone()  # untouched copy for the "before" score

        first = self._strip_blank_icode(flexible_residue_first)
        last = self._strip_blank_icode(flexible_residue_last)
        first_idx = self._pdb_to_pose_index(pose, first)
        last_idx = self._pdb_to_pose_index(pose, last)

        gen_selector, subset_selector = self._build_selectors(first_idx, last_idx)
        tf = self._build_task_factory(subset_selector)

        mmf = MoveMapFactory()
        if self.move_bb:
            mmf.add_bb_action(move_map_action.mm_enable, gen_selector)
        mmf.add_chi_action(move_map_action.mm_enable, subset_selector)
        mm = mmf.create_movemap_from_pose(pose)

        # The FastRelax mover is shared across calls; its MoveMap and task
        # factory are replaced on every call.
        fr = self.fast_relax
        fr.set_movemap(mm)
        fr.set_task_factory(tf)
        fr.apply(pose)

        e_before = self.scorefxn(original_pose)
        e_relax = self.scorefxn(pose)
        return pose, e_before, e_relax
def _relax_and_save(task: RelaxTask, tag: str, **relax_kwargs):
    """Shared driver for the PyRosetta relax steps below.

    Returns *task* unchanged if ``task.can_proceed()`` is falsy or
    ``task.update_if_finished(tag)`` is truthy (presumably: the *tag* step
    already has output - confirm against RelaxTask). Otherwise it relaxes the
    task's flexible region with ``RelaxRegion(**relax_kwargs)``, writes the
    pose to ``task.set_current_path_tag(tag)``, and calls
    ``task.mark_success()``.
    """
    if not task.can_proceed():
        return task
    if task.update_if_finished(tag):
        return task
    minimizer = RelaxRegion(**relax_kwargs)
    pose_min, _, _ = minimizer(
        pdb_path=task.current_path,
        flexible_residue_first=task.flexible_residue_first,
        flexible_residue_last=task.flexible_residue_last,
    )
    out_path = task.set_current_path_tag(tag)
    pose_min.dump_pdb(out_path)
    task.mark_success()
    return task


def run_pyrosetta(task: RelaxTask):
    """Relax the flexible region with default RelaxRegion settings (tag 'rosetta')."""
    return _relax_and_save(task, 'rosetta')


def run_pyrosetta_fixbb(task: RelaxTask):
    """Relax the flexible region without enabling backbone moves (tag 'fixbb')."""
    return _relax_and_save(task, 'fixbb', move_bb=False)