Add ruff, run ruff and black
- hexviz/attention.py +20 -37
- hexviz/ec_number.py +1 -3
- hexviz/models.py +1 -3
- hexviz/pages/1_🗺️Identify_Interesting_Heads.py +7 -10
- hexviz/pages/2_📄Documentation.py +42 -19
- hexviz/plot.py +4 -12
- hexviz/view.py +9 -8
- hexviz/🧬Attention_Visualization.py +19 -40
- poetry.lock +10 -1
- pyproject.toml +7 -0
- tests/test_attention.py +22 -15
- tests/test_models.py +1 -2
hexviz/attention.py
CHANGED
@@ -68,18 +68,14 @@ def res_to_1letter(residues: list[Residue]) -> str:
        Residues not in the standard 20 amino acids are replaced with X
    """
    res_names = [residue.get_resname() for residue in residues]
-    residues_single_letter = map(
-        lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), res_names
-    )
+    residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), res_names)

    return "".join(list(residues_single_letter))


def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
    lines = sequence.split("\n")
-    cleaned_sequence = "".join(
-        line.upper() for line in lines if not line.startswith(">")
-    )
+    cleaned_sequence = "".join(line.upper() for line in lines if not line.startswith(">"))
    cleaned_sequence = cleaned_sequence.replace(" ", "")
    valid_residues = set(Polypeptide.protein_letters_3to1.values())
    residues_in_sequence = set(cleaned_sequence)
@@ -87,7 +83,9 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
    # Check if the sequence exceeds the max allowed length
    max_sequence_length = 400
    if len(cleaned_sequence) > max_sequence_length:
-        error_message = f"Sequence exceeds the max allowed length of {max_sequence_length} characters"
+        error_message = (
+            f"Sequence exceeds the max allowed length of {max_sequence_length} characters"
+        )
        return cleaned_sequence, error_message

    illegal_residues = residues_in_sequence - valid_residues
@@ -103,9 +101,7 @@ def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
    tokens = tokenizer.tokenize(sequence)

    indices_to_remove = [
-        i
-        for i, token in enumerate(tokens)
-        if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
+        i for i, token in enumerate(tokens) if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
    ]

    new_attentions = []
@@ -113,9 +109,7 @@ def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
    for attentions in attentions_tuple:
        # Remove rows and columns corresponding to special tokens and periods
        for idx in sorted(indices_to_remove, reverse=True):
-            attentions = torch.cat(
-                (attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2
-            )
+            attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
            attentions = torch.cat(
                (attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3
            )
@@ -131,7 +125,7 @@ def get_attention(
    sequence: str,
    model_type: ModelType = ModelType.TAPE_BERT,
    remove_special_tokens: bool = True,
-    ec_number:
+    ec_number: str = None,
):
    """
    Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
@@ -153,24 +147,18 @@ def get_attention(
        tokenizer, model = get_zymctrl()

        if ec_number:
-            sequence = f"{
+            sequence = f"{ec_number}<sep><start>{sequence}<end><pad>"

        inputs = tokenizer(sequence, return_tensors="pt").input_ids.to(device)
-        attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(
-            device
-        )
+        attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(device)

        with torch.no_grad():
-            outputs = model(
-                inputs, attention_mask=attention_mask, output_attentions=True
-            )
+            outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
            attentions = outputs.attentions

        if ec_number:
            # Remove attention to special tokens and periods separating EC number components
-            attentions = remove_special_tokens_and_periods(
-                attentions, sequence, tokenizer
-            )
+            attentions = remove_special_tokens_and_periods(attentions, sequence, tokenizer)

        # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
        attention_squeezed = [torch.squeeze(attention) for attention in attentions]
@@ -196,9 +184,7 @@ def get_attention(
        token_idxs = tokenizer.encode(sequence_separated)
        inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
        with torch.no_grad():
-            attentions = model(inputs, output_attentions=True)[
-                -1
-            ]  # Do you need an attention mask?
+            attentions = model(inputs, output_attentions=True)[-1]  # Do you need an attention mask?

        if remove_special_tokens:
            # Remove attention to </s> (last) token
@@ -262,17 +248,16 @@ def get_attention_pairs(
    top_residues = []

    ec_tag_length = 4
-
+
+    def is_tag(x):
+        return x < ec_tag_length

    for i, chain in enumerate(chains):
        ec_number = ec_numbers[i] if ec_numbers else None
+        ec_string = ".".join([ec.number for ec in ec_number]) if ec_number else ""
        sequence = res_to_1letter(chain)
-        attention = get_attention(
-
-        )
-        attention_unidirectional = unidirectional_avg_filtered(
-            attention, layer, head, threshold
-        )
+        attention = get_attention(sequence=sequence, model_type=model_type, ec_number=ec_string)
+        attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)

        # Store sum of attention in to a resiue (from the unidirectional attention)
        residue_attention = {}
@@ -305,9 +290,7 @@ def get_attention_pairs(
                residue_attention.get(res - ec_tag_length, 0) + attn_value
            )

-        top_n_residues = sorted(
-            residue_attention.items(), key=lambda x: x[1], reverse=True
-        )[:top_n]
+        top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]

        for res, attn_sum in top_n_residues:
            coord = chain[res]["CA"].coord.tolist()
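For reference, the row-and-column trimming done by remove_special_tokens_and_periods above is easy to reproduce in isolation. A minimal sketch (not part of the commit), assuming a made-up ZymCTRL-style token list and a random attention tensor:

# Not from the commit: a standalone sketch of the special-token/period removal shown above,
# using a hypothetical token list and a dummy [1, n_heads, n_tokens, n_tokens] tensor.
import torch

tokens = ["1", ".", "1", ".", "1", ".", "1", "<sep>", "<start>", "M", "K", "V", "<end>", "<pad>"]
indices_to_remove = [
    i for i, token in enumerate(tokens) if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
]

attentions = torch.rand(1, 2, len(tokens), len(tokens))

# Delete the flagged rows (dim=2) and columns (dim=3), highest index first so the
# remaining indices stay valid.
for idx in sorted(indices_to_remove, reverse=True):
    attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
    attentions = torch.cat((attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3)

print(attentions.shape)  # torch.Size([1, 2, 7, 7]): only the EC digits and residues remain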
hexviz/ec_number.py
CHANGED
@@ -6,6 +6,4 @@ class ECNumber:
        self.radius = radius

    def __str__(self):
-        return (
-            f"(EC: {self.number}, Coordinate: {self.coordinate}, Color: {self.color})"
-        )
+        return f"(EC: {self.number}, Coordinate: {self.coordinate}, Color: {self.color})"
hexviz/models.py
CHANGED
@@ -60,7 +60,5 @@ def get_prot_t5():
    tokenizer = T5Tokenizer.from_pretrained(
        "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
    )
-    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
-        device
-    )
+    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)
    return tokenizer, model
hexviz/pages/1_🗺️Identify_Interesting_Heads.py
CHANGED
@@ -27,14 +27,10 @@ models = [
    Model(name=ModelType.PROT_T5, layers=24, heads=32),
]

-with st.expander(
-    "Input a PDB id, upload a PDB file or input a sequence", expanded=True
-):
+with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
    pdb_id = select_pdb()
    uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
-    input_sequence = st.text_area(
-        "3.Input sequence", "", key="input_sequence", max_chars=400
-    )
+    input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
    sequence, error = clean_and_validate_sequence(input_sequence)
    if error:
        st.error(error)
@@ -65,7 +61,9 @@ truncated_sequence = sequence[slice_start - 1 : slice_end]
layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)

st.markdown(
-    f"Each tile is a heatmap of attention for a section of the {source} chain
+    f"""Each tile is a heatmap of attention for a section of the {source} chain
+    ({chain_selection}) from residue {slice_start} to {slice_end}. Adjust the
+    section length and starting point in the sidebar."""
)

# TODO: Decide if you should get attention for the full sequence or just the truncated sequence
@@ -74,11 +72,10 @@ attention = get_attention(
    sequence=truncated_sequence,
    model_type=selected_model.name,
    remove_special_tokens=True,
+    ec_number=ec_number,
)

-fig = plot_tiled_heatmap(
-    attention, layer_sequence=layer_sequence, head_sequence=head_sequence
-)
+fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence)


st.pyplot(fig)
hexviz/pages/2_📄Documentation.py
CHANGED
@@ -5,42 +5,65 @@ from hexviz.config import URL
st.markdown(
    f"""
## Protein language models
-There has been an explosion of capabilities in natural language processing
-These architectural advances from NLP have proven
-
+There has been an explosion of capabilities in natural language processing
+models in the last few years. These architectural advances from NLP have proven
+to work very well for protein sequences, and we now have protein language models
+(pLMs) that can generate novel functional proteins sequences
+[ProtGPT2](https://www.nature.com/articles/s42256-022-00499-z) and auto-encoding
+models that excel at capturing biophysical features of protein sequences
+[ProtTrans](https://www.biorxiv.org/content/10.1101/2020.07.12.199554v3).

-For an introduction to protein language models for protein design check out
+For an introduction to protein language models for protein design check out
+[Controllable protein design with language
+models](https://www.nature.com/articles/s42256-022-00499-z).

## Interpreting protein language models by visualizing attention patterns
-With these impressive capabilities it is natural to ask what protein language
-
-
-
+With these impressive capabilities it is natural to ask what protein language
+models are learning and how they work -- we want to **interpret** the models.
+In natural language processing **attention analysis** has proven to be a useful
+tool for interpreting transformer model internals see fex ([Abnar et al.
+2020](https://arxiv.org/abs/2005.00928v2)). [BERTology meets
+biology](https://arxiv.org/abs/2006.15222) provides a thorough introduction to
+how we can analyze Transformer protein models through the lens of attention,
+they show exciting findings such as: > Attention: (1) captures the folding
+structure of proteins, connecting amino acids that are far apart in the
+underlying sequence, but spatially close in the three-dimensional structure, (2)
+targets binding sites, a key functional component of proteins, and (3) focuses
+on progressively more complex biophysical properties with increasing layer depth

-Most existing tools for analyzing and visualizing attention patterns focus on
-
-
-
-
+Most existing tools for analyzing and visualizing attention patterns focus on
+models trained on text. It can be hard to analyze protein sequences using these
+tools as sequences can be long and we lack intuition about how the language of
+proteins work. BERTology meets biology shows visualizing attention patterns in
+the context of protein structure can facilitate novel discoveries about what
+models learn. [**Hexviz**](https://huggingface.co/spaces/aksell/hexviz) is a
+tool to simplify analyzing attention patterns in the context of protein
+structure. We hope this can enable domain experts to explore and interpret the
+knowledge contained in pLMs.

## How to use Hexviz
There are two views:
1. <a href="{URL}Attention_Visualization" target="_self">🧬Attention Visualization</a> Shows attention weights from a single head as red bars between residues on a protein structure.
2. <a href="{URL}Identify_Interesting_Heads" target="_self">🗺️Identify Interesting Heads</a> Plots attention weights between residues as a heatmap for each head in the model.

-The first view is the meat of the application and is where you can investigate
-
-
+The first view is the meat of the application and is where you can investigate
+how attention patterns map onto the structure of a protein you're interested in.
+Use the second view to narrow down to a few heads that you want to investigate
+attention patterns from in detail. pLM are large and can have many heads, as an
+example ProtBERT with it's 30 layers and 16 heads has 480 heads, so we need a
+way to identify heads with patterns we're interested in.

-The second view is a customizable heatmap plot of attention between residue for
-a
+The second view is a customizable heatmap plot of attention between residue for
+all heads and layers in a model. From here it is possible to identify heads that
+specialize in a particular attention pattern, such as:
1. Vertical lines: Paying attention so a single or a few residues
2. Diagonal: Attention to the same residue or residues in front or behind the current residue.
3. Block attention: Attention is segmented so parts of the sequence are attended to by one part of the sequence.
4. Heterogeneous: More complex attention patterns that are not easily categorized.
TODO: Add examples of attention patterns

-Read more about attention patterns in fex [Revealing the dark secrets of
+Read more about attention patterns in fex [Revealing the dark secrets of
+BERT](https://arxiv.org/abs/1908.08593).

## Protein Language models in Hexviz
Hexviz currently supports the following models:
hexviz/plot.py
CHANGED
@@ -15,30 +15,22 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in

    x_size = num_heads * 2
    y_size = num_layers * 2
-    fig, axes = plt.subplots(
-        num_layers, num_heads, figsize=(x_size, y_size), squeeze=False
-    )
+    fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
    for i in range(num_layers):
        for j in range(num_heads):
-            axes[i, j].imshow(
-                tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal"
-            )
+            axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal")
            axes[i, j].axis("off")

            # Enumerate the axes
            if i == 0:
-                axes[i, j].set_title(
-                    f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05
-                )
+                axes[i, j].set_title(f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05)

    # Calculate the row label offset based on the number of columns
    offset = 0.02 + (12 - num_heads) * 0.0015
    for i, ax_row in enumerate(axes):
        row_label = f"{layer_sequence[i]+1}"
        row_pos = ax_row[num_heads - 1].get_position()
-        fig.text(
-            row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center"
-        )
+        fig.text(row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center")

    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    return fig
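The tiled layout produced by plot_tiled_heatmap is easiest to see with dummy data. A minimal standalone sketch (not part of the commit) that mirrors the subplot grid above, using a random [n_layers, n_heads, n_res, n_res] tensor and hypothetical layer/head selections:

# Not from the commit: the tiled-heatmap idea with random data instead of model attention.
import matplotlib.pyplot as plt
import torch

layer_sequence = [0, 1, 2]    # hypothetical layers to show
head_sequence = [0, 1, 2, 3]  # hypothetical heads to show
attention = torch.rand(len(layer_sequence), len(head_sequence), 50, 50)

fig, axes = plt.subplots(len(layer_sequence), len(head_sequence), figsize=(8, 6), squeeze=False)
for i in range(len(layer_sequence)):
    for j in range(len(head_sequence)):
        # One heatmap per (layer, head) tile, heads as columns and layers as rows
        axes[i, j].imshow(attention[i, j].numpy(), cmap="viridis", aspect="equal")
        axes[i, j].axis("off")
        if i == 0:
            axes[i, j].set_title(f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05)

fig.savefig("tiled_heatmap.png")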
hexviz/view.py
CHANGED
@@ -18,11 +18,7 @@ def get_selecte_model_index(models):
        return 0
    else:
        return next(
-            (
-                i
-                for i, model in enumerate(models)
-                if model.name.value == selected_model_name
-            ),
+            (i for i, model in enumerate(models) if model.name.value == selected_model_name),
            None,
        )

@@ -89,10 +85,10 @@ def select_protein(pdb_code, uploaded_file, input_sequence):
        pdb_str = get_pdb_from_seq(str(input_sequence))
        if "selected_chains" in st.session_state:
            del st.session_state.selected_chains
-        source =
+        source = "Input sequence + ESM-fold"
    elif "uploaded_pdb_str" in st.session_state:
        pdb_str = st.session_state.uploaded_pdb_str
-        source =
+        source = "Uploaded file stored in cache"
    else:
        file = get_pdb_file(pdb_code)
        pdb_str = file.read()
@@ -135,7 +131,12 @@ def select_heads_and_layers(sidebar, model):


def select_sequence_slice(sequence_length):
-    st.sidebar.markdown(
+    st.sidebar.markdown(
+        """
+        Sequence segment to plot
+        ---
+        """
+    )
    if "sequence_slice" not in st.session_state:
        st.session_state.sequence_slice = (1, min(50, sequence_length))
    slice = st.sidebar.slider(
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -31,14 +31,10 @@ models = [
    Model(name=ModelType.PROT_T5, layers=24, heads=32),
]

-with st.expander(
-    "Input a PDB id, upload a PDB file or input a sequence", expanded=True
-):
-    pdb_id = select_pdb()
+with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
+    pdb_id = select_pdb() or "2WK4"
    uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
-    input_sequence = st.text_area(
-        "3.Input sequence", "", key="input_sequence", max_chars=400
-    )
+    input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
    sequence, error = clean_and_validate_sequence(input_sequence)
    if error:
        st.error(error)
@@ -59,9 +55,7 @@ selected_chains = st.sidebar.multiselect(
    label="Select Chain(s)", options=chains, key="selected_chains"
)

-show_ligands = st.sidebar.checkbox(
-    "Show ligands", value=st.session_state.get("show_ligands", True)
-)
+show_ligands = st.sidebar.checkbox("Show ligands", value=st.session_state.get("show_ligands", True))
st.session_state.show_ligands = show_ligands


@@ -71,9 +65,7 @@ st.sidebar.markdown(
    ---
    """
)
-min_attn = st.sidebar.slider(
-    "Minimum attention", min_value=0.0, max_value=0.4, value=0.1
-)
+min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
n_highest_resis = st.sidebar.number_input(
    "Num highest attention resis to label", value=2, min_value=1, max_value=100
)
@@ -84,9 +76,7 @@ sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)

with st.sidebar.expander("Label residues manually"):
    hl_chain = st.selectbox(label="Chain to label", options=selected_chains, index=0)
-    hl_resi_list = st.multiselect(
-        label="Selected Residues", options=list(range(1, 5000))
-    )
+    hl_resi_list = st.multiselect(label="Selected Residues", options=list(range(1, 5000)))

    label_resi = st.checkbox(label="Label Residues", value=True)

@@ -97,10 +87,13 @@ with left:
with mid:
    if "selected_layer" not in st.session_state:
        st.session_state["selected_layer"] = 5
-    layer_one =
-
-
-
+    layer_one = (
+        st.selectbox(
+            "Layer",
+            options=[i for i in range(1, selected_model.layers + 1)],
+            key="selected_layer",
+        )
+        or 5
    )
    layer = layer_one - 1
with right:
@@ -135,9 +128,7 @@ if selected_model.name == ModelType.ZymCTRL:

    if ec_number:
        if selected_chains:
-            shown_chains = [
-                ch for ch in structure.get_chains() if ch.id in selected_chains
-            ]
+            shown_chains = [ch for ch in structure.get_chains() if ch.id in selected_chains]
        else:
            shown_chains = list(structure.get_chains())

@@ -163,14 +154,9 @@ if selected_model.name == ModelType.ZymCTRL:
        reverse_vector = [-v for v in vector]

        # Normalize the reverse vector
-        reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
-            reverse_vector
-        )
+        reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(reverse_vector)
        coordinates = [
-            [
-                res_1[j] + i * 2 * radius * reverse_vector_normalized[j]
-                for j in range(3)
-            ]
+            [res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
            for i in range(4)
        ]
        EC_tag = [
@@ -213,9 +199,7 @@ def get_3dview(pdb):
    for chain in hidden_chains:
        xyzview.setStyle({"chain": chain}, {"cross": {"hidden": "true"}})
        # Hide ligands for chain too
-        xyzview.addStyle(
-            {"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}}
-        )
+        xyzview.addStyle({"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}})

    if len(selected_chains) == 1:
        xyzview.zoomTo({"chain": f"{selected_chains[0]}"})
@@ -257,7 +241,6 @@ def get_3dview(pdb):
    for _, _, chain, res in top_residues:
        one_indexed_res = res + 1
        xyzview.addResLabels(
-
            {"chain": chain, "resi": one_indexed_res},
            {
                "backgroundColor": "lightgray",
@@ -266,9 +249,7 @@ def get_3dview(pdb):
            },
        )
        if sidechain_highest:
-            xyzview.addStyle(
-                {"chain": chain, "resi": res}, {"stick": {"radius": 0.2}}
-            )
+            xyzview.addStyle({"chain": chain, "resi": res}, {"stick": {"radius": 0.2}})
    return xyzview


@@ -282,9 +263,7 @@ Pick a PDB ID, layer and head to visualize attention from the selected protein l
    unsafe_allow_html=True,
)

-chain_dict = {
-    f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())
-}
+chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
data = []
for att_weight, _, chain, resi in top_residues:
    try:
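The EC-tag rendering above places four spheres, one per EC number component, starting at the first residue and walking away from the chain, spaced two radii apart along the reversed first-to-second-residue vector. A small numeric sketch (not part of the commit) with made-up coordinates:

# Not from the commit: a numeric sketch of the EC-tag sphere placement shown above,
# with made-up residue coordinates (res_1, vector) and radius.
import numpy as np

res_1 = np.array([10.0, 10.0, 10.0])  # CA coordinate of the first residue (hypothetical)
vector = np.array([1.0, 2.0, 2.0])    # direction from the first towards the second residue
radius = 1.0

reverse_vector = -vector
reverse_vector_normalized = reverse_vector / np.linalg.norm(reverse_vector)

# Four sphere centres spaced 2 * radius apart, so consecutive spheres just touch.
coordinates = [res_1 + i * 2 * radius * reverse_vector_normalized for i in range(4)]
for i, coord in enumerate(coordinates):
    print(f"EC component {i}: {coord.round(2)}")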
poetry.lock
CHANGED
@@ -1609,6 +1609,14 @@ pygments = ">=2.13.0,<3.0.0"
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]

+[[package]]
+name = "ruff"
+version = "0.0.264"
+description = "An extremely fast Python linter, written in Rust."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
[[package]]
name = "s3transfer"
version = "0.6.0"
@@ -2196,7 +2204,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
[metadata]
lock-version = "1.1"
python-versions = "^3.10"
-content-hash = "
+content-hash = "502949174f23054a4b450dfc0bb16df64c43d7d6c3e60d1adaf2835962223c32"

[metadata.files]
altair = []
@@ -2428,6 +2436,7 @@ requests = []
rfc3339-validator = []
rfc3986-validator = []
rich = []
+ruff = []
s3transfer = []
scipy = []
semver = []
pyproject.toml
CHANGED
@@ -14,6 +14,7 @@ torch = "^2.0.0"
sentencepiece = "^0.1.97"
tape-proteins = "^0.5"
matplotlib = "^3.7.1"
+ruff = "^0.0.264"

[tool.poetry.dev-dependencies]
pytest = "^7.2.2"
@@ -21,3 +22,9 @@ pytest = "^7.2.2"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
+
+[tool.ruff]
+line-length = 100
+
+[tool.black]
+line-length = 100
tests/test_attention.py
CHANGED
@@ -1,8 +1,13 @@
import torch
from Bio.PDB.Structure import Structure

-from hexviz.attention import (
-
+from hexviz.attention import (
+    ModelType,
+    get_attention,
+    get_sequences,
+    get_structure,
+    unidirectional_avg_filtered,
+)


def test_get_structure():
@@ -12,10 +17,11 @@ def test_get_structure():
    assert structure is not None
    assert isinstance(structure, Structure)

+
def test_get_sequences():
    pdb_id = "1AKE"
    structure = get_structure(pdb_id)
-
+
    sequences = get_sequences(structure)

    assert sequences is not None
@@ -30,26 +36,29 @@ def test_get_attention_zymctrl():
    result = get_attention("GGG", model_type=ModelType.ZymCTRL)

    assert result is not None
-    assert result.shape == torch.Size([36,16,3,3])
+    assert result.shape == torch.Size([36, 16, 3, 3])
+

def test_get_attention_zymctrl_long_chain():
-    structure = get_structure(pdb_code="6A5J")
+    structure = get_structure(pdb_code="6A5J")  # 13 residues long

    sequences = get_sequences(structure)

    result = get_attention(sequences[0], model_type=ModelType.ZymCTRL)

    assert result is not None
-    assert result.shape == torch.Size([36,16,13,13])
+    assert result.shape == torch.Size([36, 16, 13, 13])
+

def test_get_attention_tape():
-    structure = get_structure(pdb_code="6A5J")
+    structure = get_structure(pdb_code="6A5J")  # 13 residues long
    sequences = get_sequences(structure)

    result = get_attention(sequences[0], model_type=ModelType.TAPE_BERT)

    assert result is not None
-    assert result.shape == torch.Size([12,12,13,13])
+    assert result.shape == torch.Size([12, 12, 13, 13])
+

def test_get_attention_prot_bert():

@@ -58,21 +67,19 @@ def test_get_attention_prot_bert():
    assert result is not None
    assert result.shape == torch.Size([30, 16, 3, 3])

+
def test_get_unidirection_avg_filtered():
    # 1 head, 1 layer, 4 residues long attention tensor
-    attention= torch.tensor(
-
-
-        [4, 7, 9, 11]]]], dtype=torch.float32)
+    attention = torch.tensor(
+        [[[[1, 2, 3, 4], [2, 5, 6, 7], [3, 6, 8, 9], [4, 7, 9, 11]]]], dtype=torch.float32
+    )

    result = unidirectional_avg_filtered(attention, 0, 0, 0)

    assert result is not None
    assert len(result) == 10

-    attention = torch.tensor([[[[1, 2, 3],
-                                [2, 5, 6],
-                                [4, 7, 91]]]], dtype=torch.float32)
+    attention = torch.tensor([[[[1, 2, 3], [2, 5, 6], [4, 7, 91]]]], dtype=torch.float32)

    result = unidirectional_avg_filtered(attention, 0, 0, 0)

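unidirectional_avg_filtered itself is not shown in this diff. One plausible shape for it, consistent with the test expecting 10 entries for a 4-residue tensor (4 * 5 / 2 unordered pairs including the diagonal), is sketched below; this is an assumption, not the hexviz implementation:

# Not from the commit, and not the hexviz implementation: a sketch of the behaviour the
# test above exercises, assuming attention[layer, head] is averaged over both directions
# of each residue pair (i <= j) and filtered by a minimum value.
import torch


def unidirectional_avg_filtered_sketch(attention, layer, head, threshold):
    attn = attention[layer, head]
    n_res = attn.shape[0]
    pairs = []
    for i in range(n_res):
        for j in range(i, n_res):
            avg = (attn[i, j] + attn[j, i]) / 2
            if avg > threshold:
                pairs.append((avg.item(), i, j))
    return pairs


attention = torch.tensor(
    [[[[1, 2, 3, 4], [2, 5, 6, 7], [3, 6, 8, 9], [4, 7, 9, 11]]]], dtype=torch.float32
)
print(len(unidirectional_avg_filtered_sketch(attention, 0, 0, 0)))  # 10 pairs for 4 residues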
tests/test_models.py
CHANGED
@@ -1,4 +1,3 @@
-
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

from hexviz.models import get_zymctrl
@@ -13,4 +12,4 @@ def test_get_zymctrl():
    tokenizer, model = result

    assert isinstance(tokenizer, GPT2TokenizerFast)
-    assert isinstance(model, GPT2LMHeadModel)
+    assert isinstance(model, GPT2LMHeadModel)