DiffAb / diffab /tools /relax /pyrosetta_relaxer.py
luost26's picture
Update
753e275
# pyright: reportMissingImports=false
import os
import time
import pyrosetta
from pyrosetta.rosetta.protocols.relax import FastRelax
from pyrosetta.rosetta.core.pack.task import TaskFactory
from pyrosetta.rosetta.core.pack.task import operation
from pyrosetta.rosetta.core.select import residue_selector as selections
from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action
# Initialise PyRosetta once, as a module-level side effect of importing this
# file. The option tokens listed below are joined with single spaces into one
# command-line style string, which is what pyrosetta.init() receives.
# NOTE(review): the meaning of each flag is defined by Rosetta's option system,
# not by this file -- consult the Rosetta options documentation before changing.
pyrosetta.init(' '.join([
'-mute', 'all',
'-use_input_sc',
'-ignore_unrecognized_res',
'-ignore_zero_occupancy', 'false',
'-load_PDB_components', 'false',
'-relax:default_repeats', '2',
'-no_fconfig',
]))
from diffab.tools.relax.base import RelaxTask
def current_milli_time():
    """Return the current ``time.time()`` reading in milliseconds, rounded to an int."""
    seconds_now = time.time()
    millis_now = seconds_now * 1000
    return round(millis_now)
def parse_residue_position(p):
    """Split a residue position string into its components.

    The string is read as ``<chain><resseq>[<icode>]``: everything before the
    first numeric character is the chain, the following digits are the residue
    sequence number, and a single trailing non-numeric character (if present)
    is the insertion code.

    Args:
        p: Position string, e.g. ``"H100"`` or ``"H100A"``.

    Returns:
        ``(chain, resseq)`` when there is no insertion code, otherwise
        ``(chain, resseq, icode)``. ``resseq`` is an ``int``.

    Raises:
        ValueError: If ``p`` is empty, contains no numeric character, or the
            residue-number part cannot be converted to ``int``.
    """
    if not p:
        raise ValueError('Residue position string is empty.')

    icode = None
    numbered = p
    if not p[-1].isnumeric():  # Has ICODE
        icode = p[-1]
        # Strip the insertion code so it is not fed to int() below; previously
        # only a blank icode worked, because int() tolerates whitespace.
        numbered = p[:-1]

    first_digit = next(
        (idx for idx, char in enumerate(numbered) if char.isnumeric()),
        None,
    )
    if first_digit is None:
        raise ValueError(f'No residue number found in position {p!r}.')

    chain = numbered[:first_digit]
    resseq = int(numbered[first_digit:])
    if icode is not None:
        return chain, resseq, icode
    return chain, resseq
def get_scorefxn(scorefxn_name:str):
    """
    Build a score function by name after setting the ``corrections:*`` options.

    The name is matched by substring against a fixed list of score-function
    families; the first match decides which single ``corrections:`` flag is
    switched on, and the other three are switched off. A name that matches no
    family (e.g. one without any of the substrings tested below) turns all four
    flags off.

    Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550

    Args:
        scorefxn_name: Name handed to ``pyrosetta.create_score_function``.

    Returns:
        Whatever ``pyrosetta.create_score_function(scorefxn_name)`` returns.
    """
    import pyrosetta
    set_flag = pyrosetta.rosetta.basic.options.set_boolean_option

    # Pick the one correction flag (if any) that this name enables.
    if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name:
        # beta_july15 is ref2015
        enabled = 'beta_july15'
    elif 'beta_nov16' in scorefxn_name:
        enabled = 'beta_nov16'
    elif 'genpot' in scorefxn_name:
        enabled = 'gen_potential'
        # NOTE(review): the loop below immediately sets corrections:beta_july15
        # back to False for this branch -- confirm whether that is intended.
        set_flag('corrections:beta_july15', True)
    elif 'talaris' in scorefxn_name:  # 2013 and 2014
        enabled = 'restore_talaris_behavior'
    else:
        enabled = None

    # Write every flag explicitly so a previous call's setting never leaks through.
    for flag in ('beta_july15', 'beta_nov16', 'gen_potential', 'restore_talaris_behavior'):
        set_flag(f'corrections:{flag}', flag == enabled)

    return pyrosetta.create_score_function(scorefxn_name)
class RelaxRegion(object):
    """Callable that runs PyRosetta ``FastRelax`` on a residue range of a PDB file.

    The instance holds a score function and a configured ``FastRelax`` mover.
    Calling it loads a PDB file, restricts packing and minimisation to the
    requested region, applies the mover, and returns the relaxed pose together
    with the scores before and after.
    """

    _SUBSETS = ('all', 'target', 'nbrs')

    def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True):
        """Configure the score function and the FastRelax mover.

        Args:
            scorefxn: Score function name, resolved through ``get_scorefxn``.
            max_iter: Value forwarded to ``FastRelax.max_iter``.
            subset: Which residues may repack and have their chi angles moved:
                ``'target'`` -- only the flexible residue range;
                ``'nbrs'`` -- the range plus whatever
                ``NeighborhoodResidueSelector`` selects around it;
                ``'all'`` -- every residue (no repacking restriction added).
            move_bb: If true, backbone movement is enabled for the flexible
                residue range; otherwise no backbone action is added.

        Raises:
            ValueError: If ``subset`` is not one of the values above.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if subset not in self._SUBSETS:
            raise ValueError(
                f'subset must be one of {self._SUBSETS}, got {subset!r}'
            )
        self.scorefxn = get_scorefxn(scorefxn)
        self.fast_relax = FastRelax()
        self.fast_relax.set_scorefxn(self.scorefxn)
        self.fast_relax.max_iter(max_iter)
        self.subset = subset
        self.move_bb = move_bb

    def __call__(self, pdb_path, flexible_residue_first, flexible_residue_last):
        """Relax the given residue range of the structure stored at ``pdb_path``.

        Args:
            pdb_path: Path of the PDB file to load.
            flexible_residue_first: Sequence unpacked into
                ``pose.pdb_info().pdb2pose(...)`` to locate the first flexible
                residue. A trailing ``' '`` element is dropped beforehand.
                NOTE(review): presumably ``(chain, resseq, icode)`` -- confirm
                against the caller.
            flexible_residue_last: Same, for the last flexible residue.

        Returns:
            Tuple ``(pose, e_before, e_relax)``: the relaxed pose, the score of
            the unrelaxed input pose, and the score of the relaxed pose.
        """
        original_pose = pyrosetta.pose_from_pdb(pdb_path)
        pose = original_pose.clone()

        tf = TaskFactory()
        tf.push_back(operation.InitializeFromCommandline())
        tf.push_back(operation.RestrictToRepacking())  # Only allow residues to repack. No design at any position.

        # A blank trailing element is removed so it is not passed on to pdb2pose.
        if flexible_residue_first[-1] == ' ':
            flexible_residue_first = flexible_residue_first[:-1]
        if flexible_residue_last[-1] == ' ':
            flexible_residue_last = flexible_residue_last[:-1]

        # Selector for the flexible range. Built for every `subset` value,
        # because the move map below needs it even when subset == 'all'
        # (previously that configuration failed with a NameError).
        pdb_info = original_pose.pdb_info()
        gen_selector = selections.ResidueIndexSelector()
        gen_selector.set_index_range(
            pdb_info.pdb2pose(*flexible_residue_first),
            pdb_info.pdb2pose(*flexible_residue_last),
        )

        # Selector for the residues allowed to repack / move their chi angles.
        if self.subset == 'all':
            subset_selector = selections.TrueResidueSelector()
        elif self.subset == 'nbrs':
            nbr_selector = selections.NeighborhoodResidueSelector()
            nbr_selector.set_focus_selector(gen_selector)
            nbr_selector.set_include_focus_in_subset(True)
            subset_selector = nbr_selector
        elif self.subset == 'target':
            subset_selector = gen_selector
        else:
            raise ValueError(
                f'subset must be one of {self._SUBSETS}, got {self.subset!r}'
            )

        # Turn off repacking on every residue outside the chosen subset.
        # Not needed for 'all', where nothing is excluded.
        if self.subset != 'all':
            prevent_subset_repacking = operation.OperateOnResidueSubset(
                operation.PreventRepackingRLT(),
                subset_selector,
                flip_subset=True,
            )
            tf.push_back(prevent_subset_repacking)

        mmf = MoveMapFactory()
        if self.move_bb:
            mmf.add_bb_action(move_map_action.mm_enable, gen_selector)
        mmf.add_chi_action(move_map_action.mm_enable, subset_selector)
        mm = mmf.create_movemap_from_pose(pose)

        self.fast_relax.set_movemap(mm)
        self.fast_relax.set_task_factory(tf)
        self.fast_relax.apply(pose)

        e_before = self.scorefxn(original_pose)
        e_relax = self.scorefxn(pose)
        return pose, e_before, e_relax
def run_pyrosetta(task: RelaxTask):
    """Relax ``task``'s current structure with a default ``RelaxRegion`` and tag it ``'rosetta'``.

    The task is returned untouched when ``task.can_proceed()`` is falsy, or when
    ``task.update_if_finished('rosetta')`` is truthy. Otherwise the relaxed pose
    is dumped to the path returned by ``task.set_current_path_tag('rosetta')``
    and the task is marked successful.

    Returns:
        The same ``task`` object that was passed in.
    """
    # Short-circuit keeps update_if_finished from being called on a blocked task.
    if not task.can_proceed() or task.update_if_finished('rosetta'):
        return task

    relaxer = RelaxRegion()
    relaxed_pose, _e_before, _e_relax = relaxer(
        pdb_path=task.current_path,
        flexible_residue_first=task.flexible_residue_first,
        flexible_residue_last=task.flexible_residue_last,
    )

    destination = task.set_current_path_tag('rosetta')
    relaxed_pose.dump_pdb(destination)
    task.mark_success()
    return task
def run_pyrosetta_fixbb(task: RelaxTask):
    """Relax ``task``'s current structure with ``RelaxRegion(move_bb=False)`` and tag it ``'fixbb'``.

    The task is returned untouched when ``task.can_proceed()`` is falsy, or when
    ``task.update_if_finished('fixbb')`` is truthy. Otherwise the relaxed pose
    is dumped to the path returned by ``task.set_current_path_tag('fixbb')``
    and the task is marked successful.

    Returns:
        The same ``task`` object that was passed in.
    """
    # Short-circuit keeps update_if_finished from being called on a blocked task.
    if not task.can_proceed() or task.update_if_finished('fixbb'):
        return task

    relaxer = RelaxRegion(move_bb=False)
    relaxed_pose, _e_before, _e_relax = relaxer(
        pdb_path=task.current_path,
        flexible_residue_first=task.flexible_residue_first,
        flexible_residue_last=task.flexible_residue_last,
    )

    destination = task.set_current_path_tag('fixbb')
    relaxed_pose.dump_pdb(destination)
    task.mark_success()
    return task