aksell committed
Commit a71a737
1 Parent(s): 7a18fac

Add ruff, run ruff and black

hexviz/attention.py CHANGED
@@ -68,18 +68,14 @@ def res_to_1letter(residues: list[Residue]) -> str:
     Residues not in the standard 20 amino acids are replaced with X
     """
     res_names = [residue.get_resname() for residue in residues]
-    residues_single_letter = map(
-        lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), res_names
-    )
+    residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), res_names)
 
     return "".join(list(residues_single_letter))
 
 
 def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     lines = sequence.split("\n")
-    cleaned_sequence = "".join(
-        line.upper() for line in lines if not line.startswith(">")
-    )
+    cleaned_sequence = "".join(line.upper() for line in lines if not line.startswith(">"))
     cleaned_sequence = cleaned_sequence.replace(" ", "")
     valid_residues = set(Polypeptide.protein_letters_3to1.values())
     residues_in_sequence = set(cleaned_sequence)
@@ -87,7 +83,9 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     # Check if the sequence exceeds the max allowed length
     max_sequence_length = 400
     if len(cleaned_sequence) > max_sequence_length:
-        error_message = f"Sequence exceeds the max allowed length of {max_sequence_length} characters"
+        error_message = (
+            f"Sequence exceeds the max allowed length of {max_sequence_length} characters"
+        )
         return cleaned_sequence, error_message
 
     illegal_residues = residues_in_sequence - valid_residues
@@ -103,9 +101,7 @@ def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
     tokens = tokenizer.tokenize(sequence)
 
     indices_to_remove = [
-        i
-        for i, token in enumerate(tokens)
-        if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
+        i for i, token in enumerate(tokens) if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
     ]
 
     new_attentions = []
@@ -113,9 +109,7 @@ def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
     for attentions in attentions_tuple:
         # Remove rows and columns corresponding to special tokens and periods
         for idx in sorted(indices_to_remove, reverse=True):
-            attentions = torch.cat(
-                (attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2
-            )
+            attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
             attentions = torch.cat(
                 (attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3
             )
@@ -131,7 +125,7 @@ def get_attention(
     sequence: str,
     model_type: ModelType = ModelType.TAPE_BERT,
     remove_special_tokens: bool = True,
-    ec_number: list[ECNumber] = None,
+    ec_number: str = None,
 ):
     """
     Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
@@ -153,24 +147,18 @@ def get_attention(
         tokenizer, model = get_zymctrl()
 
         if ec_number:
-            sequence = f"{'.'.join([ec.number for ec in ec_number])}<sep><start>{sequence}<end><pad>"
+            sequence = f"{ec_number}<sep><start>{sequence}<end><pad>"
 
         inputs = tokenizer(sequence, return_tensors="pt").input_ids.to(device)
-        attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(
-            device
-        )
+        attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(device)
 
         with torch.no_grad():
-            outputs = model(
-                inputs, attention_mask=attention_mask, output_attentions=True
-            )
+            outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
             attentions = outputs.attentions
 
         if ec_number:
             # Remove attention to special tokens and periods separating EC number components
-            attentions = remove_special_tokens_and_periods(
-                attentions, sequence, tokenizer
-            )
+            attentions = remove_special_tokens_and_periods(attentions, sequence, tokenizer)
 
         # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
         attention_squeezed = [torch.squeeze(attention) for attention in attentions]
@@ -196,9 +184,7 @@ def get_attention(
         token_idxs = tokenizer.encode(sequence_separated)
         inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
         with torch.no_grad():
-            attentions = model(inputs, output_attentions=True)[
-                -1
-            ]  # Do you need an attention mask?
+            attentions = model(inputs, output_attentions=True)[-1]  # Do you need an attention mask?
 
         if remove_special_tokens:
             # Remove attention to </s> (last) token
@@ -262,17 +248,16 @@ def get_attention_pairs(
     top_residues = []
 
     ec_tag_length = 4
-    is_tag = lambda x: x < ec_tag_length
+
+    def is_tag(x):
+        return x < ec_tag_length
 
     for i, chain in enumerate(chains):
         ec_number = ec_numbers[i] if ec_numbers else None
+        ec_string = ".".join([ec.number for ec in ec_number]) if ec_number else ""
         sequence = res_to_1letter(chain)
-        attention = get_attention(
-            sequence=sequence, model_type=model_type, ec_number=ec_number
-        )
-        attention_unidirectional = unidirectional_avg_filtered(
-            attention, layer, head, threshold
-        )
+        attention = get_attention(sequence=sequence, model_type=model_type, ec_number=ec_string)
+        attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
 
         # Store sum of attention in to a resiue (from the unidirectional attention)
         residue_attention = {}
@@ -305,9 +290,7 @@
                 residue_attention.get(res - ec_tag_length, 0) + attn_value
             )
 
-        top_n_residues = sorted(
-            residue_attention.items(), key=lambda x: x[1], reverse=True
-        )[:top_n]
+        top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
 
         for res, attn_sum in top_n_residues:
            coord = chain[res]["CA"].coord.tolist()

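Review note: the reflowed `remove_special_tokens_and_periods` keeps the same index-removal trick, concatenating the slices before and after each special-token position. A minimal sketch of that behaviour on a dummy tensor (the shape and index below are illustrative, not taken from any of the supported models):

```python
# Minimal sketch of the torch.cat index removal used above; the [batch, heads, seq, seq]
# shape and the chosen index are illustrative only.
import torch

attn = torch.rand(1, 2, 5, 5)  # dummy attention for a 5-token sequence
idx = 3  # pretend token 3 is a special token

# Drop row idx (dim=2) and then column idx (dim=3), mirroring the loop in the diff.
attn = torch.cat((attn[:, :, :idx], attn[:, :, idx + 1 :]), dim=2)
attn = torch.cat((attn[:, :, :, :idx], attn[:, :, :, idx + 1 :]), dim=3)
assert attn.shape == (1, 2, 4, 4)
```
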
hexviz/ec_number.py CHANGED
@@ -6,6 +6,4 @@ class ECNumber:
         self.radius = radius
 
     def __str__(self):
-        return (
-            f"(EC: {self.number}, Coordinate: {self.coordinate}, Color: {self.color})"
-        )
+        return f"(EC: {self.number}, Coordinate: {self.coordinate}, Color: {self.color})"

hexviz/models.py CHANGED
@@ -60,7 +60,5 @@ def get_prot_t5():
     tokenizer = T5Tokenizer.from_pretrained(
         "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
     )
-    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
-        device
-    )
+    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)
     return tokenizer, model

hexviz/pages/1_🗺️Identify_Interesting_Heads.py CHANGED
@@ -27,14 +27,10 @@ models = [
     Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
 
-with st.expander(
-    "Input a PDB id, upload a PDB file or input a sequence", expanded=True
-):
+with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
     pdb_id = select_pdb()
     uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
-    input_sequence = st.text_area(
-        "3.Input sequence", "", key="input_sequence", max_chars=400
-    )
+    input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
     sequence, error = clean_and_validate_sequence(input_sequence)
     if error:
         st.error(error)
@@ -65,7 +61,9 @@ truncated_sequence = sequence[slice_start - 1 : slice_end]
 layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
 
 st.markdown(
-    f"Each tile is a heatmap of attention for a section of the {source} chain ({chain_selection}) from residue {slice_start} to {slice_end}. Adjust the section length and starting point in the sidebar."
+    f"""Each tile is a heatmap of attention for a section of the {source} chain
+    ({chain_selection}) from residue {slice_start} to {slice_end}. Adjust the
+    section length and starting point in the sidebar."""
 )
 
 # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
@@ -74,11 +72,10 @@ attention = get_attention(
     sequence=truncated_sequence,
     model_type=selected_model.name,
     remove_special_tokens=True,
+    ec_number=ec_number,
 )
 
-fig = plot_tiled_heatmap(
-    attention, layer_sequence=layer_sequence, head_sequence=head_sequence
-)
+fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence)
 
 
 st.pyplot(fig)

hexviz/pages/2_📄Documentation.py CHANGED
@@ -5,42 +5,65 @@ from hexviz.config import URL
 st.markdown(
     f"""
 ## Protein language models
-There has been an explosion of capabilities in natural language processing models in the last few years.
-These architectural advances from NLP have proven to work very well for protein sequences, and we now have protein language models (pLMs) that can generate novel functional proteins sequences [ProtGPT2](https://www.nature.com/articles/s42256-022-00499-z)
-and auto-encoding models that excel at capturing biophysical features of protein sequences [ProtTrans](https://www.biorxiv.org/content/10.1101/2020.07.12.199554v3).
+There has been an explosion of capabilities in natural language processing
+models in the last few years. These architectural advances from NLP have proven
+to work very well for protein sequences, and we now have protein language models
+(pLMs) that can generate novel functional proteins sequences
+[ProtGPT2](https://www.nature.com/articles/s42256-022-00499-z) and auto-encoding
+models that excel at capturing biophysical features of protein sequences
+[ProtTrans](https://www.biorxiv.org/content/10.1101/2020.07.12.199554v3).
 
-For an introduction to protein language models for protein design check out [Controllable protein design with language models](https://www.nature.com/articles/s42256-022-00499-z).
+For an introduction to protein language models for protein design check out
+[Controllable protein design with language
+models](https://www.nature.com/articles/s42256-022-00499-z).
 
 ## Interpreting protein language models by visualizing attention patterns
-With these impressive capabilities it is natural to ask what protein language models are learning and how they work -- we want to **interpret** the models.
-In natural language processing **attention analysis** has proven to be a useful tool for interpreting transformer model internals see fex ([Abnar et al. 2020](https://arxiv.org/abs/2005.00928v2)).
-[BERTology meets biology](https://arxiv.org/abs/2006.15222) provides a thorough introduction to how we can analyze Transformer protein models through the lens of attention, they show exciting findings such as:
-> Attention: (1) captures the folding structure of proteins, connecting amino acids that are far apart in the underlying sequence, but spatially close in the three-dimensional structure, (2) targets binding sites, a key functional component of proteins, and (3) focuses on progressively more complex biophysical properties with increasing layer depth
+With these impressive capabilities it is natural to ask what protein language
+models are learning and how they work -- we want to **interpret** the models.
+In natural language processing **attention analysis** has proven to be a useful
+tool for interpreting transformer model internals see fex ([Abnar et al.
+2020](https://arxiv.org/abs/2005.00928v2)). [BERTology meets
+biology](https://arxiv.org/abs/2006.15222) provides a thorough introduction to
+how we can analyze Transformer protein models through the lens of attention,
+they show exciting findings such as: > Attention: (1) captures the folding
+structure of proteins, connecting amino acids that are far apart in the
+underlying sequence, but spatially close in the three-dimensional structure, (2)
+targets binding sites, a key functional component of proteins, and (3) focuses
+on progressively more complex biophysical properties with increasing layer depth
 
-Most existing tools for analyzing and visualizing attention patterns focus on models trained on text. It can be hard to analyze protein sequences using these tools as
-sequences can be long and we lack intuition about how the language of proteins work.
-BERTology meets biology shows visualizing attention patterns in the context of protein structure can facilitate novel discoveries about what models learn.
-[**Hexviz**](https://huggingface.co/spaces/aksell/hexviz) is a tool to simplify analyzing attention patterns in the context of protein structure. We hope this can enable
-domain experts to explore and interpret the knowledge contained in pLMs.
+Most existing tools for analyzing and visualizing attention patterns focus on
+models trained on text. It can be hard to analyze protein sequences using these
+tools as sequences can be long and we lack intuition about how the language of
+proteins work. BERTology meets biology shows visualizing attention patterns in
+the context of protein structure can facilitate novel discoveries about what
+models learn. [**Hexviz**](https://huggingface.co/spaces/aksell/hexviz) is a
+tool to simplify analyzing attention patterns in the context of protein
+structure. We hope this can enable domain experts to explore and interpret the
+knowledge contained in pLMs.
 
 ## How to use Hexviz
 There are two views:
 1. <a href="{URL}Attention_Visualization" target="_self">🧬Attention Visualization</a> Shows attention weights from a single head as red bars between residues on a protein structure.
 2. <a href="{URL}Identify_Interesting_Heads" target="_self">🗺️Identify Interesting Heads</a> Plots attention weights between residues as a heatmap for each head in the model.
 
-The first view is the meat of the application and is where you can investigate how attention patterns map onto the structure of a protein you're interested in.
-Use the second view to narrow down to a few heads that you want to investigate attention patterns from in detail.
-pLM are large and can have many heads, as an example ProtBERT with it's 30 layers and 16 heads has 480 heads, so we need a way to identify heads with patterns we're interested in.
+The first view is the meat of the application and is where you can investigate
+how attention patterns map onto the structure of a protein you're interested in.
+Use the second view to narrow down to a few heads that you want to investigate
+attention patterns from in detail. pLM are large and can have many heads, as an
+example ProtBERT with it's 30 layers and 16 heads has 480 heads, so we need a
+way to identify heads with patterns we're interested in.
 
-The second view is a customizable heatmap plot of attention between residue for all heads and layers in a model. From here it is possible to identify heads that specialize in
-a particular attention pattern, such as:
+The second view is a customizable heatmap plot of attention between residue for
+all heads and layers in a model. From here it is possible to identify heads that
+specialize in a particular attention pattern, such as:
 1. Vertical lines: Paying attention so a single or a few residues
 2. Diagonal: Attention to the same residue or residues in front or behind the current residue.
 3. Block attention: Attention is segmented so parts of the sequence are attended to by one part of the sequence.
 4. Heterogeneous: More complex attention patterns that are not easily categorized.
 TODO: Add examples of attention patterns
 
-Read more about attention patterns in fex [Revealing the dark secrets of BERT](https://arxiv.org/abs/1908.08593).
+Read more about attention patterns in fex [Revealing the dark secrets of
+BERT](https://arxiv.org/abs/1908.08593).
 
 ## Protein Language models in Hexviz
 Hexviz currently supports the following models:

hexviz/plot.py CHANGED
@@ -15,30 +15,22 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
 
     x_size = num_heads * 2
     y_size = num_layers * 2
-    fig, axes = plt.subplots(
-        num_layers, num_heads, figsize=(x_size, y_size), squeeze=False
-    )
+    fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
     for i in range(num_layers):
         for j in range(num_heads):
-            axes[i, j].imshow(
-                tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal"
-            )
+            axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal")
             axes[i, j].axis("off")
 
             # Enumerate the axes
             if i == 0:
-                axes[i, j].set_title(
-                    f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05
-                )
+                axes[i, j].set_title(f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05)
 
     # Calculate the row label offset based on the number of columns
     offset = 0.02 + (12 - num_heads) * 0.0015
     for i, ax_row in enumerate(axes):
         row_label = f"{layer_sequence[i]+1}"
         row_pos = ax_row[num_heads - 1].get_position()
-        fig.text(
-            row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center"
-        )
+        fig.text(row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center")
 
     plt.subplots_adjust(wspace=0.1, hspace=0.1)
     return fig

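For orientation, a hedged usage sketch of `plot_tiled_heatmap` after the reflow; the random tensor and the layer/head picks below are placeholders, not values taken from the app:

```python
# Hypothetical smoke test for plot_tiled_heatmap; the random tensor stands in for real
# attention already sliced to shape [len(layer_sequence), len(head_sequence), n_res, n_res].
import torch
from hexviz.plot import plot_tiled_heatmap

layer_sequence = [0, 5, 11]   # placeholder 0-based layer indices
head_sequence = [0, 1, 2, 3]  # placeholder 0-based head indices
dummy = torch.rand(len(layer_sequence), len(head_sequence), 50, 50)

fig = plot_tiled_heatmap(dummy, layer_sequence=layer_sequence, head_sequence=head_sequence)
fig.savefig("tiled_heatmap.png")  # one heatmap tile per selected layer/head pair
```
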
hexviz/view.py CHANGED
@@ -18,11 +18,7 @@ def get_selecte_model_index(models):
         return 0
     else:
         return next(
-            (
-                i
-                for i, model in enumerate(models)
-                if model.name.value == selected_model_name
-            ),
+            (i for i, model in enumerate(models) if model.name.value == selected_model_name),
             None,
         )
 
@@ -89,10 +85,10 @@ def select_protein(pdb_code, uploaded_file, input_sequence):
         pdb_str = get_pdb_from_seq(str(input_sequence))
         if "selected_chains" in st.session_state:
             del st.session_state.selected_chains
-        source = f"Input sequence + ESM-fold"
+        source = "Input sequence + ESM-fold"
     elif "uploaded_pdb_str" in st.session_state:
         pdb_str = st.session_state.uploaded_pdb_str
-        source = f"Uploaded file stored in cache"
+        source = "Uploaded file stored in cache"
     else:
         file = get_pdb_file(pdb_code)
         pdb_str = file.read()
@@ -135,7 +131,12 @@ def select_heads_and_layers(sidebar, model):
 
 
 def select_sequence_slice(sequence_length):
-    st.sidebar.markdown("Sequence segment to plot")
+    st.sidebar.markdown(
+        """
+        Sequence segment to plot
+        ---
+        """
+    )
     if "sequence_slice" not in st.session_state:
         st.session_state.sequence_slice = (1, min(50, sequence_length))
     slice = st.sidebar.slider(

hexviz/🧬Attention_Visualization.py CHANGED
@@ -31,14 +31,10 @@ models = [
     Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
 
-with st.expander(
-    "Input a PDB id, upload a PDB file or input a sequence", expanded=True
-):
-    pdb_id = select_pdb()
+with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
+    pdb_id = select_pdb() or "2WK4"
     uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
-    input_sequence = st.text_area(
-        "3.Input sequence", "", key="input_sequence", max_chars=400
-    )
+    input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
     sequence, error = clean_and_validate_sequence(input_sequence)
     if error:
         st.error(error)
@@ -59,9 +55,7 @@ selected_chains = st.sidebar.multiselect(
     label="Select Chain(s)", options=chains, key="selected_chains"
 )
 
-show_ligands = st.sidebar.checkbox(
-    "Show ligands", value=st.session_state.get("show_ligands", True)
-)
+show_ligands = st.sidebar.checkbox("Show ligands", value=st.session_state.get("show_ligands", True))
 st.session_state.show_ligands = show_ligands
 
 
@@ -71,9 +65,7 @@ st.sidebar.markdown(
     ---
     """
 )
-min_attn = st.sidebar.slider(
-    "Minimum attention", min_value=0.0, max_value=0.4, value=0.1
-)
+min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
 n_highest_resis = st.sidebar.number_input(
     "Num highest attention resis to label", value=2, min_value=1, max_value=100
 )
@@ -84,9 +76,7 @@ sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
 
 with st.sidebar.expander("Label residues manually"):
     hl_chain = st.selectbox(label="Chain to label", options=selected_chains, index=0)
-    hl_resi_list = st.multiselect(
-        label="Selected Residues", options=list(range(1, 5000))
-    )
+    hl_resi_list = st.multiselect(label="Selected Residues", options=list(range(1, 5000)))
 
     label_resi = st.checkbox(label="Label Residues", value=True)
 
@@ -97,10 +87,13 @@ with left:
 with mid:
     if "selected_layer" not in st.session_state:
         st.session_state["selected_layer"] = 5
-    layer_one = st.selectbox(
-        "Layer",
-        options=[i for i in range(1, selected_model.layers + 1)],
-        key="selected_layer",
+    layer_one = (
+        st.selectbox(
+            "Layer",
+            options=[i for i in range(1, selected_model.layers + 1)],
+            key="selected_layer",
+        )
+        or 5
     )
     layer = layer_one - 1
 with right:
@@ -135,9 +128,7 @@ if selected_model.name == ModelType.ZymCTRL:
 
     if ec_number:
         if selected_chains:
-            shown_chains = [
-                ch for ch in structure.get_chains() if ch.id in selected_chains
-            ]
+            shown_chains = [ch for ch in structure.get_chains() if ch.id in selected_chains]
         else:
             shown_chains = list(structure.get_chains())
 
@@ -163,14 +154,9 @@ if selected_model.name == ModelType.ZymCTRL:
             reverse_vector = [-v for v in vector]
 
             # Normalize the reverse vector
-            reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
-                reverse_vector
-            )
+            reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(reverse_vector)
             coordinates = [
-                [
-                    res_1[j] + i * 2 * radius * reverse_vector_normalized[j]
-                    for j in range(3)
-                ]
+                [res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
                 for i in range(4)
             ]
             EC_tag = [
@@ -213,9 +199,7 @@ def get_3dview(pdb):
     for chain in hidden_chains:
         xyzview.setStyle({"chain": chain}, {"cross": {"hidden": "true"}})
         # Hide ligands for chain too
-        xyzview.addStyle(
-            {"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}}
-        )
+        xyzview.addStyle({"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}})
 
     if len(selected_chains) == 1:
         xyzview.zoomTo({"chain": f"{selected_chains[0]}"})
@@ -257,7 +241,6 @@ def get_3dview(pdb):
     for _, _, chain, res in top_residues:
         one_indexed_res = res + 1
         xyzview.addResLabels(
-
             {"chain": chain, "resi": one_indexed_res},
             {
                 "backgroundColor": "lightgray",
@@ -266,9 +249,7 @@ def get_3dview(pdb):
             },
         )
         if sidechain_highest:
-            xyzview.addStyle(
-                {"chain": chain, "resi": res}, {"stick": {"radius": 0.2}}
-            )
+            xyzview.addStyle({"chain": chain, "resi": res}, {"stick": {"radius": 0.2}})
     return xyzview
 
 
@@ -282,9 +263,7 @@ Pick a PDB ID, layer and head to visualize attention from the selected protein l
     unsafe_allow_html=True,
 )
 
-chain_dict = {
-    f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())
-}
+chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
 data = []
 for att_weight, _, chain, resi in top_residues:
     try:

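Aside on the compacted `xyzview` calls: assuming `xyzview` is a py3Dmol view (which the method names suggest, but the diff does not show the import), the hide-a-chain-and-its-ligands pattern looks like this in isolation; the PDB id and chain list are placeholders:

```python
# Standalone sketch of the chain/ligand hiding pattern from get_3dview; "1AKE" and
# chain "B" are placeholder choices, not values used by Hexviz.
import py3Dmol

view = py3Dmol.view(query="pdb:1AKE", width=600, height=400)
view.setStyle({"cartoon": {"color": "spectrum"}})
for chain in ["B"]:  # chains the user chose to hide
    view.setStyle({"chain": chain}, {"cross": {"hidden": "true"}})
    # Hide that chain's hetero atoms (ligands) as well
    view.addStyle({"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}})
view.zoomTo({"chain": "A"})
```
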
poetry.lock CHANGED
@@ -1609,6 +1609,14 @@ pygments = ">=2.13.0,<3.0.0"
 [package.extras]
 jupyter = ["ipywidgets (>=7.5.1,<9)"]
 
+[[package]]
+name = "ruff"
+version = "0.0.264"
+description = "An extremely fast Python linter, written in Rust."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
 [[package]]
 name = "s3transfer"
 version = "0.6.0"
@@ -2196,7 +2204,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.10"
-content-hash = "79f191c2f3cc09035f7d0f543aa08c44c0a39de462336ca23eeea26ac29218de"
+content-hash = "502949174f23054a4b450dfc0bb16df64c43d7d6c3e60d1adaf2835962223c32"
 
 [metadata.files]
 altair = []
@@ -2428,6 +2436,7 @@ requests = []
 rfc3339-validator = []
 rfc3986-validator = []
 rich = []
+ruff = []
 s3transfer = []
 scipy = []
 semver = []

pyproject.toml CHANGED
@@ -14,6 +14,7 @@ torch = "^2.0.0"
 sentencepiece = "^0.1.97"
 tape-proteins = "^0.5"
 matplotlib = "^3.7.1"
+ruff = "^0.0.264"
 
 [tool.poetry.dev-dependencies]
 pytest = "^7.2.2"
@@ -21,3 +22,9 @@ pytest = "^7.2.2"
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.ruff]
+line-length = 100
+
+[tool.black]
+line-length = 100

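The commit message also says the tools were run. A hedged sketch of doing that from Python (the helper below is my illustration; the usual route is simply invoking the `ruff` and `black` CLIs, both of which pick up line-length = 100 from the tables above):

```python
# Hypothetical helper mirroring "run ruff and black"; both tools read their
# line-length = 100 settings from [tool.ruff] and [tool.black] in pyproject.toml.
import subprocess

def lint_and_format(path: str = ".") -> None:
    # ruff exits non-zero when unfixable violations remain, so don't hard-fail on it here.
    subprocess.run(["ruff", path, "--fix"])
    subprocess.run(["black", path], check=True)

if __name__ == "__main__":
    lint_and_format()
```
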
tests/test_attention.py CHANGED
@@ -1,8 +1,13 @@
 import torch
 from Bio.PDB.Structure import Structure
 
-from hexviz.attention import (ModelType, get_attention, get_sequences,
-                              get_structure, unidirectional_avg_filtered)
+from hexviz.attention import (
+    ModelType,
+    get_attention,
+    get_sequences,
+    get_structure,
+    unidirectional_avg_filtered,
+)
 
 
 def test_get_structure():
@@ -12,10 +17,11 @@ def test_get_structure():
     assert structure is not None
     assert isinstance(structure, Structure)
 
+
 def test_get_sequences():
     pdb_id = "1AKE"
     structure = get_structure(pdb_id)
-
+
     sequences = get_sequences(structure)
 
     assert sequences is not None
@@ -30,26 +36,29 @@ def test_get_attention_zymctrl():
     result = get_attention("GGG", model_type=ModelType.ZymCTRL)
 
     assert result is not None
-    assert result.shape == torch.Size([36,16,3,3])
+    assert result.shape == torch.Size([36, 16, 3, 3])
+
 
 def test_get_attention_zymctrl_long_chain():
-    structure = get_structure(pdb_code="6A5J") # 13 residues long
+    structure = get_structure(pdb_code="6A5J")  # 13 residues long
 
     sequences = get_sequences(structure)
 
     result = get_attention(sequences[0], model_type=ModelType.ZymCTRL)
 
     assert result is not None
-    assert result.shape == torch.Size([36,16,13,13])
+    assert result.shape == torch.Size([36, 16, 13, 13])
+
 
 def test_get_attention_tape():
-    structure = get_structure(pdb_code="6A5J") # 13 residues long
+    structure = get_structure(pdb_code="6A5J")  # 13 residues long
     sequences = get_sequences(structure)
 
     result = get_attention(sequences[0], model_type=ModelType.TAPE_BERT)
 
     assert result is not None
-    assert result.shape == torch.Size([12,12,13,13])
+    assert result.shape == torch.Size([12, 12, 13, 13])
+
 
 def test_get_attention_prot_bert():
 
@@ -58,21 +67,19 @@ def test_get_attention_prot_bert():
     assert result is not None
     assert result.shape == torch.Size([30, 16, 3, 3])
 
+
 def test_get_unidirection_avg_filtered():
     # 1 head, 1 layer, 4 residues long attention tensor
-    attention= torch.tensor([[[[1, 2, 3, 4],
-                               [2, 5, 6, 7],
-                               [3, 6, 8, 9],
-                               [4, 7, 9, 11]]]], dtype=torch.float32)
+    attention = torch.tensor(
+        [[[[1, 2, 3, 4], [2, 5, 6, 7], [3, 6, 8, 9], [4, 7, 9, 11]]]], dtype=torch.float32
+    )
 
     result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
     assert result is not None
     assert len(result) == 10
 
-    attention = torch.tensor([[[[1, 2, 3],
-                                [2, 5, 6],
-                                [4, 7, 91]]]], dtype=torch.float32)
+    attention = torch.tensor([[[[1, 2, 3], [2, 5, 6], [4, 7, 91]]]], dtype=torch.float32)
 
     result = unidirectional_avg_filtered(attention, 0, 0, 0)
 

tests/test_models.py CHANGED
@@ -1,4 +1,3 @@
-
 from transformers import GPT2LMHeadModel, GPT2TokenizerFast
 
 from hexviz.models import get_zymctrl
@@ -13,4 +12,4 @@ def test_get_zymctrl():
     tokenizer, model = result
 
     assert isinstance(tokenizer, GPT2TokenizerFast)
-    assert isinstance(model, GPT2LMHeadModel)
+    assert isinstance(model, GPT2LMHeadModel)