aksell committed on
Commit
7c0cc4e
·
1 Parent(s): 01ba5a1

Fix residue indexing error

Browse files

Now all calculations use zero-indexed residues.
Previously we'd get the wrong index, partly due to the one-based indexing
of Biopython's default chain[index] API, and partly due to PDB files
using different starting indexes for their first residue.

hexviz/attention.py CHANGED
@@ -5,6 +5,7 @@ import requests
5
  import streamlit as st
6
  import torch
7
  from Bio.PDB import PDBParser, Polypeptide, Structure
 
8
 
9
  from hexviz.ec_number import ECNumber
10
  from hexviz.models import (
@@ -60,16 +61,15 @@ def get_chains(structure: Structure) -> list[str]:
60
  return chains
61
 
62
 
63
- def get_sequence(chain) -> str:
64
  """
65
- Get sequence from a chain
66
 
67
  Residues not in the standard 20 amino acids are replaced with X
68
  """
69
- residues = [residue.get_resname() for residue in chain.get_residues()]
70
- # TODO ask if using protein_letters_3to1_extended makes sense
71
  residues_single_letter = map(
72
- lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues
73
  )
74
 
75
  return "".join(list(residues_single_letter))
@@ -243,11 +243,20 @@ def get_attention_pairs(
243
  top_n: int = 2,
244
  ec_numbers: list[list[ECNumber]] | None = None,
245
  ):
 
 
 
246
  structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
 
247
  if chain_ids:
248
  chains = [ch for ch in structure.get_chains() if ch.id in chain_ids]
249
  else:
250
  chains = list(structure.get_chains())
 
 
 
 
 
251
 
252
  attention_pairs = []
253
  top_residues = []
@@ -257,7 +266,7 @@ def get_attention_pairs(
257
 
258
  for i, chain in enumerate(chains):
259
  ec_number = ec_numbers[i] if ec_numbers else None
260
- sequence = get_sequence(chain)
261
  attention = get_attention(
262
  sequence=sequence, model_type=model_type, ec_number=ec_number
263
  )
@@ -270,7 +279,6 @@ def get_attention_pairs(
270
  for attn_value, res_1, res_2 in attention_unidirectional:
271
  try:
272
  if not ec_number:
273
- # Should you add 1 here? Arent chains 1 indexed and res indexeds 0 indexed
274
  coord_1 = chain[res_1]["CA"].coord.tolist()
275
  coord_2 = chain[res_2]["CA"].coord.tolist()
276
  else:
@@ -303,6 +311,6 @@ def get_attention_pairs(
303
 
304
  for res, attn_sum in top_n_residues:
305
  coord = chain[res]["CA"].coord.tolist()
306
- top_residues.append((attn_sum, coord, chain.id, res))
307
 
308
  return attention_pairs, top_residues
 
5
  import streamlit as st
6
  import torch
7
  from Bio.PDB import PDBParser, Polypeptide, Structure
8
+ from Bio.PDB.Residue import Residue
9
 
10
  from hexviz.ec_number import ECNumber
11
  from hexviz.models import (
 
61
  return chains
62
 
63
 
64
+ def res_to_1letter(residues: list[Residue]) -> str:
65
  """
66
+ Get single letter sequence from a list of Residues
67
 
68
  Residues not in the standard 20 amino acids are replaced with X
69
  """
70
+ res_names = [residue.get_resname() for residue in residues]
 
71
  residues_single_letter = map(
72
+ lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), res_names
73
  )
74
 
75
  return "".join(list(residues_single_letter))
 
243
  top_n: int = 2,
244
  ec_numbers: list[list[ECNumber]] | None = None,
245
  ):
246
+ """
247
+ Note: All residue indexes returned are 0 indexed
248
+ """
249
  structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
250
+
251
  if chain_ids:
252
  chains = [ch for ch in structure.get_chains() if ch.id in chain_ids]
253
  else:
254
  chains = list(structure.get_chains())
255
+ # Chains are treated as lists of residues to make indexing easier
256
+ # and to avoid troubles with residues in PDB files not having a consistent
257
+ # start index
258
+ chain_ids = [chain.id for chain in chains]
259
+ chains = [[res for res in chain.get_residues()] for chain in chains]
260
 
261
  attention_pairs = []
262
  top_residues = []
 
266
 
267
  for i, chain in enumerate(chains):
268
  ec_number = ec_numbers[i] if ec_numbers else None
269
+ sequence = res_to_1letter(chain)
270
  attention = get_attention(
271
  sequence=sequence, model_type=model_type, ec_number=ec_number
272
  )
 
279
  for attn_value, res_1, res_2 in attention_unidirectional:
280
  try:
281
  if not ec_number:
 
282
  coord_1 = chain[res_1]["CA"].coord.tolist()
283
  coord_2 = chain[res_2]["CA"].coord.tolist()
284
  else:
 
311
 
312
  for res, attn_sum in top_n_residues:
313
  coord = chain[res]["CA"].coord.tolist()
314
+ top_residues.append((attn_sum, coord, chain_ids[i], res))
315
 
316
  return attention_pairs, top_residues
hexviz/🧬Attention_Visualization.py CHANGED
@@ -255,8 +255,10 @@ def get_3dview(pdb):
255
 
256
  if label_highest:
257
  for _, _, chain, res in top_residues:
 
258
  xyzview.addResLabels(
259
- {"chain": chain, "resi": res},
 
260
  {
261
  "backgroundColor": "lightgray",
262
  "fontColor": "black",
@@ -280,7 +282,9 @@ Pick a PDB ID, layer and head to visualize attention from the selected protein l
280
  unsafe_allow_html=True,
281
  )
282
 
283
- chain_dict = {f"{chain.id}": chain for chain in list(structure.get_chains())}
 
 
284
  data = []
285
  for att_weight, _, chain, resi in top_residues:
286
  try:
 
255
 
256
  if label_highest:
257
  for _, _, chain, res in top_residues:
258
+ one_indexed_res = res + 1
259
  xyzview.addResLabels(
260
+
261
+ {"chain": chain, "resi": one_indexed_res},
262
  {
263
  "backgroundColor": "lightgray",
264
  "fontColor": "black",
 
282
  unsafe_allow_html=True,
283
  )
284
 
285
+ chain_dict = {
286
+ f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())
287
+ }
288
  data = []
289
  for att_weight, _, chain, resi in top_residues:
290
  try: