Cpt-Nemo commited on
Commit
c0985d1
·
verified ·
1 Parent(s): 758f213

Update boltz_gradio.py

Browse files
Files changed (1) hide show
  1. boltz_gradio.py +73 -37
boltz_gradio.py CHANGED
@@ -9,6 +9,7 @@ import plotly.graph_objects as go
9
  from yaml import safe_dump, safe_load
10
  from rdkit import Chem, RDLogger
11
  from rdkit.Chem import AllChem, Descriptors
 
12
  from rdkit.Geometry import Point3D
13
  from rdkit.Chem.rdDetermineBonds import DetermineConnectivity
14
  from rdkit.Contrib.SA_Score import sascorer # type: ignore
@@ -56,10 +57,10 @@ property_functions = {'Molecular Weight' : Descriptors.MolWt,
56
  'Formal Charge' : lambda mol: sum([atom.GetFormalCharge() for atom in mol.GetAtoms()]),
57
  'Num. of Heavy Atoms' : Descriptors.HeavyAtomCount,
58
  'Num. of Atoms' : lambda mol: mol.GetNumAtoms(),
59
- 'Molar Refractivity' : Descriptors.MolMR,
60
- 'Quantitative Estimate of Drug-Likeness (QED)' : Descriptors.qed,
61
- 'Natural Product-likeness Score (NP)': partial(npscorer.scoreMol, fscore=fscore),
62
- 'Synthetic Accessibility Score (SA)': sascorer.calculateScore}
63
 
64
  file_extract_matching_map = {'Structure' : ['.cif', '.sdf', '_bust.csv'],
65
  'Confidence': ['confidence_'],
@@ -351,8 +352,8 @@ def __extract_cif_ca_coord(cif_f: str, get_weight: bool=True):
351
  if i in backbone_idx]
352
  bb_coords_conf = np.array(bb_coords_conf, float)
353
  conf = bb_coords_conf[:, -1]/100
354
- thres = 0.4
355
- return bb_coords_conf[:, :3], mmcif_dict, (np.maximum(conf-thres, 0.05) / (1-thres)) ** 2
356
  else:
357
  bb_coords = [[x, y, z] for i, (x, y, z) in enumerate(zip(mmcif_dict['_atom_site.Cartn_x'],
358
  mmcif_dict['_atom_site.Cartn_y'],
@@ -446,8 +447,10 @@ def execute_single_boltz(file_name: str, yaml_str: str,
446
 
447
  yield gr.update(value='Predicting...', interactive=False), ''
448
  full_output = ''
 
 
449
  curr_running_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
450
- text=True, encoding="utf-8")
451
  for line in iter(curr_running_process.stdout.readline, ''):
452
  if 'The loaded checkpoint was produced with' in line or\
453
  'You are using a CUDA device' in line: # Just skip these warnings
@@ -515,8 +518,10 @@ def execute_multi_boltz(all_files: list[str],
515
 
516
  yield gr.update(value='Predicting...', interactive=False), ''
517
  full_output = ''
 
 
518
  curr_running_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
519
- text=True, encoding="utf-8")
520
  for line in iter(curr_running_process.stdout.readline, ''):
521
  if 'The loaded checkpoint was produced with' in line or\
522
  'You are using a CUDA device' in line:
@@ -607,11 +612,14 @@ def execute_vhts_boltz(file_prefix: str, all_ligands: pd.DataFrame,
607
  f.write(safe_dump(yaml_template_dict))
608
 
609
  # execute on only a single file to retrieve msa, prevent colabfold server overload
 
610
  if idx == 0:
611
  yield gr.update(value='Predicting...', interactive=False), ''
612
  full_output = ''
 
 
613
  curr_running_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
614
- text=True, encoding="utf-8")
615
  for line in iter(curr_running_process.stdout.readline, ''):
616
  if 'The loaded checkpoint was produced with' in line or\
617
  'You are using a CUDA device' in line: # Just skip these warnings
@@ -639,7 +647,7 @@ def execute_vhts_boltz(file_prefix: str, all_ligands: pd.DataFrame,
639
  cmd[6] = str(devices) # replace the "devices" param back to user-defined value
640
 
641
  curr_running_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
642
- text=True, encoding="utf-8")
643
  for line in iter(curr_running_process.stdout.readline, ''):
644
  if 'The loaded checkpoint was produced with' in line or\
645
  'You are using a CUDA device' in line:
@@ -711,9 +719,11 @@ def _process_single_chem_file(chem_f: str):
711
  n = __check_smi_title_line(chem_f)
712
  mols = Chem.MultithreadedSmilesMolSupplier(chem_f, titleLine=n)
713
  names, smiles = [], []
 
714
  for mol in mols:
715
  if mol is None:
716
  continue
 
717
  if mol.HasProp('_Name'):
718
  name = mol.GetProp('_Name')
719
  else:
@@ -753,6 +763,7 @@ def _process_tabular_files(chem_f: list[str], name_col: str, chem_col: str, deli
753
  except:
754
  return [], []
755
  final_names, final_smiles = [], []
 
756
  for _, row in df.iterrows():
757
  name = row[name_col]
758
  chem_str = row[chem_col]
@@ -761,6 +772,7 @@ def _process_tabular_files(chem_f: list[str], name_col: str, chem_col: str, deli
761
  else:
762
  mol = Chem.MolFromSmiles(chem_str)
763
  if mol is not None:
 
764
  smi = Chem.MolToSmiles(mol)
765
  final_names.append(name)
766
  final_smiles.append(smi)
@@ -1695,6 +1707,13 @@ def draw_smiles_3d(smiles_str: str):
1695
  if isinstance(v, float):
1696
  v = round(v, 4)
1697
  data_dict['Value'].append(v)
 
 
 
 
 
 
 
1698
  yield get_mol_molstar_html(''), gr.update(value=pd.DataFrame(data_dict))
1699
  new_mol = rdkit_embed_with_timeout(mol, 60)
1700
  if new_mol is None:
@@ -2361,6 +2380,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as Interface:
2361
 
2362
  @gr.render(inputs=vhts_entity_number)
2363
  def vhts_append_new_entity(counts: int):
 
2364
  component_refs = []
2365
  for i in range(counts):
2366
  gr.Markdown(f'<span style="font-size:15px; font-weight:bold;">Entity {i+1}</span>', key=f'MK_{i}')
@@ -2394,11 +2414,16 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as Interface:
2394
  elem_classes='validation',
2395
  show_legend=True)
2396
  with gr.Column(key=f'Entity_{i}_sub3', scale=1):
2397
- cyclic_ckbox = gr.Checkbox(False, label='Cyclic', key=f'vhts_Cyclic_{i}')
 
 
 
2398
  modification_text = gr.Text(label='Modifications (Residue:CCD)',
2399
- placeholder='2:ALY,15:MSE', key=f'vhts_Mod_{i}')
 
 
2400
  component_refs.extend([entity_menu, chain_name_text, sequence_text,
2401
- cyclic_ckbox, modification_text])
2402
  entity_menu.change(change_sequence_label,
2403
  inputs=[entity_menu, sequence_text, cyclic_ckbox],
2404
  outputs=[sequence_text, highlight_text, cyclic_ckbox])
@@ -2407,17 +2432,19 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as Interface:
2407
  outputs=highlight_text)
2408
 
2409
  gr.HTML("<hr>")
2410
- chain_components = [comp for i, comp in enumerate(component_refs) if i % 5 == 1]
2411
- entity_components = [comp for i, comp in enumerate(component_refs) if i % 5 == 0]
2412
- for i, chain_input in enumerate(chain_components):
 
 
2413
  chain_input.submit(vhts_update_all_chains_dropdown,
2414
  inputs=chain_components,
2415
  outputs=[vhts_contact_1_dropdown, vhts_contact_2_dropdown,
2416
  vhts_target_chain_ids])
2417
- entity_components[i].change(vhts_update_all_chains_dropdown,
2418
- inputs=chain_components,
2419
- outputs=[vhts_contact_1_dropdown, vhts_contact_2_dropdown,
2420
- vhts_target_chain_ids])
2421
 
2422
  def write_yaml_func(binder, target, pocket_max_d, pocket_f, aff_binder,
2423
  cont_1_c, cont_1_r, cont_2_c, cont_2_r, contact_max_dist, contact_f,
@@ -2486,11 +2513,12 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as Interface:
2486
  data_dict.update({'templates': all_templates})
2487
 
2488
  existing_chains = []
 
2489
 
2490
- all_components += ['Ligand', binder, 'c1ccccc1', False, '']
2491
 
2492
- for i in range(0, len(all_components), 5):
2493
- entity, chain, seq, cyclic, mod = all_components[i:i+5]
2494
  seq = seq.strip()
2495
 
2496
  # set entity type
@@ -2504,36 +2532,36 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as Interface:
2504
  if len(chains) == 1:
2505
  id = chain.strip()
2506
  if id in existing_chains:
2507
- return f'Chain {id} of Entity {i//5+1} already existed!'
2508
  existing_chains.append(id)
2509
  else:
2510
  id = [c.strip() for c in chains]
2511
  for _i in id:
2512
  if id.count(_i) > 1:
2513
- return f'Duplicate chain found within Entity {i//5+1}!'
2514
  if _i in existing_chains:
2515
- return f'Chain {id} of Entity {i//5+1} already existed!'
2516
  existing_chains.extend(id)
2517
 
2518
  # set key of sequence ('sequence', 'ccd' or 'smiles')
2519
  if not seq:
2520
- return f'Entity {i//5+1} is empty!'
2521
  if entity == 'CCD':
2522
- seq = seq.upper()
2523
  seq_key = 'ccd'
2524
- if not re.fullmatch(r'(?:[A-Z0-9]{3}|[A-Z0-9]{5}|[A-Z]{2})', seq):
2525
- return f'Entity {i//5+1} is not a valid CCD ID!'
 
2526
  elif entity == 'Ligand':
2527
  seq_key = 'smiles'
2528
  if Chem.MolFromSmiles(seq) is None:
2529
- return f'Entity {i//5+1} is not a valid SMILES!'
2530
  else:
2531
  seq = seq.upper()
2532
  seq_key = 'sequence'
2533
  valid_strs = allow_char_dict[entity]
2534
  for char in seq:
2535
  if char not in valid_strs:
2536
- return f'Entity {i//5+1} is not a valid {entity}!'
2537
 
2538
  # set modification
2539
  if mod:
@@ -2541,7 +2569,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as Interface:
2541
  all_mods = mod.split(',')
2542
  for pos_ccd in all_mods:
2543
  if ':' not in pos_ccd:
2544
- return (f'Invalid modification for Entity {i//5+1}, please use ":" to '
2545
  f'separate residue and CCD!\n')
2546
  pos, ccd = pos_ccd.split(':')
2547
  modifications.append({'position': int(pos), 'ccd': ccd})
@@ -2550,13 +2578,21 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as Interface:
2550
 
2551
  if entity_type == 'ligand':
2552
  curr_dict = {entity_type: {'id' : id,
2553
- seq_key : seq,}
2554
- }
2555
  else:
2556
  curr_dict = {entity_type: {'id' : id,
2557
  seq_key : seq.upper(),
2558
- 'cyclic': cyclic}
2559
- }
 
 
 
 
 
 
 
 
 
2560
  if modifications is not None:
2561
  curr_dict[entity_type]['modifications'] = modifications
2562
 
 
9
  from yaml import safe_dump, safe_load
10
  from rdkit import Chem, RDLogger
11
  from rdkit.Chem import AllChem, Descriptors
12
+ from rdkit.Chem.SaltRemover import SaltRemover
13
  from rdkit.Geometry import Point3D
14
  from rdkit.Chem.rdDetermineBonds import DetermineConnectivity
15
  from rdkit.Contrib.SA_Score import sascorer # type: ignore
 
57
  'Formal Charge' : lambda mol: sum([atom.GetFormalCharge() for atom in mol.GetAtoms()]),
58
  'Num. of Heavy Atoms' : Descriptors.HeavyAtomCount,
59
  'Num. of Atoms' : lambda mol: mol.GetNumAtoms(),
60
+ 'Molar Refractivity' : Descriptors.MolMR}
61
+ property_functions_no_H = {'Quantitative Estimate of Drug-Likeness (QED)' : Descriptors.qed,
62
+ 'Natural Product-likeness Score (NP)': partial(npscorer.scoreMol, fscore=fscore),
63
+ 'Synthetic Accessibility Score (SA)': sascorer.calculateScore}
64
 
65
  file_extract_matching_map = {'Structure' : ['.cif', '.sdf', '_bust.csv'],
66
  'Confidence': ['confidence_'],
 
352
  if i in backbone_idx]
353
  bb_coords_conf = np.array(bb_coords_conf, float)
354
  conf = bb_coords_conf[:, -1]/100
355
+ thres = 0.5
356
+ return bb_coords_conf[:, :3], mmcif_dict, (np.maximum(conf-thres, 0.) / (1-thres)) ** 2
357
  else:
358
  bb_coords = [[x, y, z] for i, (x, y, z) in enumerate(zip(mmcif_dict['_atom_site.Cartn_x'],
359
  mmcif_dict['_atom_site.Cartn_y'],
 
447
 
448
  yield gr.update(value='Predicting...', interactive=False), ''
449
  full_output = ''
450
+ env = dict(os.environ)
451
+ env['NCCL_P2P_DISABLE'] = '1'
452
  curr_running_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
453
+ text=True, encoding="utf-8", env=env)
454
  for line in iter(curr_running_process.stdout.readline, ''):
455
  if 'The loaded checkpoint was produced with' in line or\
456
  'You are using a CUDA device' in line: # Just skip these warnings
 
518
 
519
  yield gr.update(value='Predicting...', interactive=False), ''
520
  full_output = ''
521
+ env = dict(os.environ)
522
+ env['NCCL_P2P_DISABLE'] = '1'
523
  curr_running_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
524
+ text=True, encoding="utf-8", env=env)
525
  for line in iter(curr_running_process.stdout.readline, ''):
526
  if 'The loaded checkpoint was produced with' in line or\
527
  'You are using a CUDA device' in line:
 
612
  f.write(safe_dump(yaml_template_dict))
613
 
614
  # execute on only a single file to retrieve msa, prevent colabfold server overload
615
+ # This should work for custom MSA too (update)
616
  if idx == 0:
617
  yield gr.update(value='Predicting...', interactive=False), ''
618
  full_output = ''
619
+ env = dict(os.environ)
620
+ env['NCCL_P2P_DISABLE'] = '1'
621
  curr_running_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
622
+ text=True, encoding="utf-8", env=env)
623
  for line in iter(curr_running_process.stdout.readline, ''):
624
  if 'The loaded checkpoint was produced with' in line or\
625
  'You are using a CUDA device' in line: # Just skip these warnings
 
647
  cmd[6] = str(devices) # replace the "devices" param back to user-defined value
648
 
649
  curr_running_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
650
+ text=True, encoding="utf-8", env=env)
651
  for line in iter(curr_running_process.stdout.readline, ''):
652
  if 'The loaded checkpoint was produced with' in line or\
653
  'You are using a CUDA device' in line:
 
719
  n = __check_smi_title_line(chem_f)
720
  mols = Chem.MultithreadedSmilesMolSupplier(chem_f, titleLine=n)
721
  names, smiles = [], []
722
+ remover = SaltRemover()
723
  for mol in mols:
724
  if mol is None:
725
  continue
726
+ mol = remover.StripMol(mol)
727
  if mol.HasProp('_Name'):
728
  name = mol.GetProp('_Name')
729
  else:
 
763
  except:
764
  return [], []
765
  final_names, final_smiles = [], []
766
+ remover = SaltRemover()
767
  for _, row in df.iterrows():
768
  name = row[name_col]
769
  chem_str = row[chem_col]
 
772
  else:
773
  mol = Chem.MolFromSmiles(chem_str)
774
  if mol is not None:
775
+ mol = remover.StripMol(mol)
776
  smi = Chem.MolToSmiles(mol)
777
  final_names.append(name)
778
  final_smiles.append(smi)
 
1707
  if isinstance(v, float):
1708
  v = round(v, 4)
1709
  data_dict['Value'].append(v)
1710
+ mol = Chem.RemoveHs(mol)
1711
+ data_dict['Property'].append(list(property_functions_no_H))
1712
+ for func in property_functions_no_H.values():
1713
+ v = func(mol)
1714
+ if isinstance(v, float):
1715
+ v = round(v, 4)
1716
+ data_dict['Value'].append(v)
1717
  yield get_mol_molstar_html(''), gr.update(value=pd.DataFrame(data_dict))
1718
  new_mol = rdkit_embed_with_timeout(mol, 60)
1719
  if new_mol is None:
 
2380
 
2381
  @gr.render(inputs=vhts_entity_number)
2382
  def vhts_append_new_entity(counts: int):
2383
+ component_cnt = 7
2384
  component_refs = []
2385
  for i in range(counts):
2386
  gr.Markdown(f'<span style="font-size:15px; font-weight:bold;">Entity {i+1}</span>', key=f'MK_{i}')
 
2414
  elem_classes='validation',
2415
  show_legend=True)
2416
  with gr.Column(key=f'Entity_{i}_sub3', scale=1):
2417
+ with gr.Row(key=f'Entity_{i}_sub3_group1_row1'):
2418
+ cyclic_ckbox = gr.Checkbox(False, label='Cyclic', min_width=50, key=f'Cyclic_{i}')
2419
+ msa_ckbox = gr.Checkbox(True, label='Use MSA', min_width=50, interactive=True,
2420
+ key=f'use_MSA_{i}')
2421
  modification_text = gr.Text(label='Modifications (Residue:CCD)',
2422
+ placeholder='2:ALY,15:MSE', key=f'Mod_{i}')
2423
+ msa_file = gr.File(label='MSA File', file_types=['.a3m', '.csv'], height=92,
2424
+ elem_classes='small-upload-style', key=f'msa_file_{i}')
2425
  component_refs.extend([entity_menu, chain_name_text, sequence_text,
2426
+ cyclic_ckbox, modification_text, msa_file, msa_ckbox])
2427
  entity_menu.change(change_sequence_label,
2428
  inputs=[entity_menu, sequence_text, cyclic_ckbox],
2429
  outputs=[sequence_text, highlight_text, cyclic_ckbox])
 
2432
  outputs=highlight_text)
2433
 
2434
  gr.HTML("<hr>")
2435
+ chain_components = [comp for i, comp in enumerate(component_refs) if i % component_cnt <= 1]
2436
+ entity_components = [comp for i, comp in enumerate(component_refs) if i % component_cnt == 0]
2437
+ for i in range(0, len(chain_components), 2):
2438
+ chain_input = chain_components[i+1]
2439
+ entity_menu = entity_components[i//2]
2440
  chain_input.submit(vhts_update_all_chains_dropdown,
2441
  inputs=chain_components,
2442
  outputs=[vhts_contact_1_dropdown, vhts_contact_2_dropdown,
2443
  vhts_target_chain_ids])
2444
+ entity_menu.change(vhts_update_all_chains_dropdown,
2445
+ inputs=chain_components,
2446
+ outputs=[vhts_contact_1_dropdown, vhts_contact_2_dropdown,
2447
+ vhts_target_chain_ids])
2448
 
2449
  def write_yaml_func(binder, target, pocket_max_d, pocket_f, aff_binder,
2450
  cont_1_c, cont_1_r, cont_2_c, cont_2_r, contact_max_dist, contact_f,
 
2513
  data_dict.update({'templates': all_templates})
2514
 
2515
  existing_chains = []
2516
+ msa_rng_name = uuid.uuid4().hex[:8]
2517
 
2518
+ all_components += ['Ligand', binder, 'c1ccccc1', False, '', '', False]
2519
 
2520
+ for i in range(0, len(all_components), component_cnt):
2521
+ entity, chain, seq, cyclic, mod, msa_pth, use_msa = all_components[i:i+component_cnt]
2522
  seq = seq.strip()
2523
 
2524
  # set entity type
 
2532
  if len(chains) == 1:
2533
  id = chain.strip()
2534
  if id in existing_chains:
2535
+ return f'Chain {id} of Entity {i//component_cnt+1} already existed!'
2536
  existing_chains.append(id)
2537
  else:
2538
  id = [c.strip() for c in chains]
2539
  for _i in id:
2540
  if id.count(_i) > 1:
2541
+ return f'Duplicate chain found within Entity {i//component_cnt+1}!'
2542
  if _i in existing_chains:
2543
+ return f'Chain {id} of Entity {i//component_cnt+1} already existed!'
2544
  existing_chains.extend(id)
2545
 
2546
  # set key of sequence ('sequence', 'ccd' or 'smiles')
2547
  if not seq:
2548
+ return f'Entity {i//component_cnt+1} is empty!'
2549
  if entity == 'CCD':
 
2550
  seq_key = 'ccd'
2551
+ seq = seq.upper()
2552
+ if not re.fullmatch(r'(?:[A-Z0-9]{3}|[A-Z0-9]{5})|[A-Z]{2}', seq):
2553
+ return f'Entity {i//component_cnt+1} is not a valid CCD ID!'
2554
  elif entity == 'Ligand':
2555
  seq_key = 'smiles'
2556
  if Chem.MolFromSmiles(seq) is None:
2557
+ return f'Entity {i//component_cnt+1} is not a valid SMILES!'
2558
  else:
2559
  seq = seq.upper()
2560
  seq_key = 'sequence'
2561
  valid_strs = allow_char_dict[entity]
2562
  for char in seq:
2563
  if char not in valid_strs:
2564
+ return f'Entity {i//component_cnt+1} is not a valid {entity}!'
2565
 
2566
  # set modification
2567
  if mod:
 
2569
  all_mods = mod.split(',')
2570
  for pos_ccd in all_mods:
2571
  if ':' not in pos_ccd:
2572
+ return (f'Invalid modification for Entity {i//component_cnt+1}, please use ":" to '
2573
  f'separate residue and CCD!\n')
2574
  pos, ccd = pos_ccd.split(':')
2575
  modifications.append({'position': int(pos), 'ccd': ccd})
 
2578
 
2579
  if entity_type == 'ligand':
2580
  curr_dict = {entity_type: {'id' : id,
2581
+ seq_key : seq,}}
 
2582
  else:
2583
  curr_dict = {entity_type: {'id' : id,
2584
  seq_key : seq.upper(),
2585
+ 'cyclic': cyclic}}
2586
+
2587
+ # Check for MSA
2588
+ if entity_type == 'protein':
2589
+ if msa_pth and use_msa:
2590
+ target_msa = os.path.join(msa_dir, msa_rng_name, os.path.basename(msa_pth))
2591
+ os.makedirs(os.path.dirname(target_msa), exist_ok=True)
2592
+ shutil.copy(msa_pth, target_msa)
2593
+ curr_dict[entity_type]['msa'] = target_msa
2594
+ elif not use_msa:
2595
+ curr_dict[entity_type]['msa'] = 'empty'
2596
  if modifications is not None:
2597
  curr_dict[entity_type]['modifications'] = modifications
2598