import time
from io import StringIO
from urllib import request

import requests
import streamlit as st
import torch
from Bio.PDB import PDBParser, Polypeptide, Structure
from Bio.PDB.Residue import Residue

from hexviz.ec_number import ECNumber
from hexviz.models import ModelType, get_prot_bert, get_prot_t5, get_tape_bert, get_zymctrl


def get_structure(pdb_code: str) -> Structure:
    """
    Get structure from PDB
    """
    pdb_url = f"https://files.rcsb.org/download/{pdb_code}.pdb"
    pdb_data = request.urlopen(pdb_url).read().decode("utf-8")
    file = StringIO(pdb_data)
    parser = PDBParser()
    structure = parser.get_structure(pdb_code, file)
    return structure


def get_pdb_file(pdb_code: str) -> StringIO:
    """
    Get the PDB file for a PDB code as an in-memory text buffer
    """
    pdb_url = f"https://files.rcsb.org/download/{pdb_code}.pdb"
    pdb_data = request.urlopen(pdb_url).read().decode("utf-8")
    file = StringIO(pdb_data)
    return file


@st.cache
def get_pdb_from_seq(sequence: str) -> str | None:
    """
    Get structure from sequence
    """
    url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
    retries = 0
    pdb_str = None
    while retries < 3 and pdb_str is None:
        response = requests.post(url, data=sequence)
        pdb_str = response.text
        if pdb_str == "INTERNAL SERVER ERROR":
            retries += 1
            time.sleep(0.1)
            pdb_str = None
    return pdb_str


def get_chains(structure: Structure) -> list[str]:
    """
    Get list of chains in a structure
    """
    chains = []
    for model in structure:
        for chain in model.get_chains():
            chains.append(chain.id)
    return chains


def res_to_1letter(residues: list[Residue]) -> str:
    """
    Get single letter sequence from a list of Residues

    Residues not in the standard 20 amino acids are replaced with X
    """
    res_names = [residue.get_resname() for residue in residues]
    residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), res_names)
    return "".join(list(residues_single_letter))


def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
    lines = sequence.split("\n")
    cleaned_sequence = "".join(line.upper() for line in lines if not line.startswith(">"))
    cleaned_sequence = cleaned_sequence.replace(" ", "")
    valid_residues = set(Polypeptide.protein_letters_3to1.values())
    residues_in_sequence = set(cleaned_sequence)

    # Check if the sequence exceeds the max allowed length
    max_sequence_length = 400
    if len(cleaned_sequence) > max_sequence_length:
        error_message = (
            f"Sequence exceeds the max allowed length of {max_sequence_length} characters"
        )
        return cleaned_sequence, error_message

    illegal_residues = residues_in_sequence - valid_residues
    if illegal_residues:
        illegal_residues_str = ", ".join(illegal_residues)
        error_message = f"Sequence contains illegal residues: {illegal_residues_str}"
        return cleaned_sequence, error_message
    else:
        return cleaned_sequence, None


def remove_tokens(attentions, tokens, tokens_to_remove):
    indices_to_remove = [i for i, token in enumerate(tokens) if token in tokens_to_remove]

    # Remove rows and columns corresponding to special tokens and periods
    for idx in sorted(indices_to_remove, reverse=True):
        attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
        attentions = torch.cat((attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3)

    return attentions
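
# The helper below is an illustrative sketch only (a hypothetical function not
# referenced by the app): it shows, on a dummy tensor, how remove_tokens drops the
# matching rows and columns from the last two attention dimensions.
def _example_remove_tokens_shapes() -> None:
    """Illustrative sketch: removing one token from a dummy [1, 1, 4, 4]
    attention tensor leaves a [1, 1, 3, 3] tensor."""
    dummy_attention = torch.rand(1, 1, 4, 4)
    pruned = remove_tokens(dummy_attention, ["A", ".", "B", "C"], tokens_to_remove=["."])
    assert pruned.shape == (1, 1, 3, 3)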
@st.cache
def get_attention(
    sequence: str,
    model_type: ModelType = ModelType.TAPE_BERT,
    remove_special_tokens: bool = True,
    ec_number: str | None = None,
):
    """
    Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
    and the sequence of tokens that the attention tensor corresponds to
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if model_type == ModelType.TAPE_BERT:
        tokenizer, model = get_tape_bert()
        token_idxs = tokenizer.encode(sequence).tolist()
        inputs = torch.tensor(token_idxs).unsqueeze(0)
        with torch.no_grad():
            attentions = model(inputs)[-1]

        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)

        if remove_special_tokens:
            # Remove attention from the first and last (special) tokens
            attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
            tokenized_sequence = tokenized_sequence[1:-1]
        attentions = torch.stack([attention.squeeze(0) for attention in attentions])

    elif model_type == ModelType.ZymCTRL:
        tokenizer, model = get_zymctrl()
        if ec_number:
            sequence = f"{ec_number}{sequence}"
        encoded = tokenizer(sequence, return_tensors="pt")
        inputs = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)

        with torch.no_grad():
            outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
            attentions = outputs.attentions

        tokenized_sequence = tokenizer.convert_ids_to_tokens(tokenizer.encode(sequence))

        if ec_number and remove_special_tokens:
            # Remove attention to special tokens and the periods separating EC number components
            tokens_to_remove = [".", "<sep>", "<start>", "<end>", "<pad>"]
            attentions = [
                remove_tokens(attention, tokenized_sequence, tokens_to_remove)
                for attention in attentions
            ]
            tokenized_sequence = [
                token for token in tokenized_sequence if token not in tokens_to_remove
            ]

        # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
        attention_squeezed = [torch.squeeze(attention) for attention in attentions]
        # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
        attention_stacked = torch.stack([attention for attention in attention_squeezed])
        attentions = attention_stacked

    elif model_type == ModelType.PROT_BERT:
        tokenizer, model = get_prot_bert()
        sequence_separated = " ".join(sequence)
        token_idxs = tokenizer.encode(sequence_separated)
        inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
        with torch.no_grad():
            attentions = model(inputs, output_attentions=True)[-1]

        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)

        if remove_special_tokens:
            # Remove attention from the first and last (special) tokens
            attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
            tokenized_sequence = tokenized_sequence[1:-1]
        attentions = torch.stack([attention.squeeze(0) for attention in attentions])

    elif model_type == ModelType.PROT_T5:
        tokenizer, model = get_prot_t5()
        sequence_separated = " ".join(sequence)
        token_idxs = tokenizer.encode(sequence_separated)
        inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
        with torch.no_grad():
            attentions = model(inputs, output_attentions=True)[-1]

        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)

        if remove_special_tokens:
            # Remove attention to the last (special) token
            attentions = [attention[:, :, :-1, :-1] for attention in attentions]
            tokenized_sequence = tokenized_sequence[:-1]
        attentions = torch.stack([attention.squeeze(0) for attention in attentions])

    else:
        raise ValueError(f"Model {model_type} not supported")

    # Transfer to CPU to avoid issues with streamlit caching
    return attentions.cpu(), tokenized_sequence
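
# The helper below is an illustrative sketch only (a hypothetical function not called
# by the app): it checks the expected relationship between get_attention's returned
# attention tensor and token list for an arbitrary short sequence. Calling it
# downloads the TAPE BERT weights, so it is not executed at import time.
def _example_get_attention() -> None:
    """Illustrative sketch: the attention tensor has shape
    [n_layers, n_heads, n_res, n_res] and n_res matches the returned token list."""
    attentions, tokens = get_attention("MKTAYIAKQR", model_type=ModelType.TAPE_BERT)
    assert attentions.shape[2] == attentions.shape[3] == len(tokens)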
def unidirectional_avg_filtered(attention, layer, head, threshold):
    _, _, seq_len, _ = attention.shape
    attention_head = attention[layer, head]
    unidirectional_avg_for_head = []
    for i in range(seq_len):
        for j in range(i, seq_len):
            # Attention matrices for BERT models are asymmetric.
            # Bidirectional attention is represented by the average of the two values
            attn_avg = (attention_head[i, j].item() + attention_head[j, i].item()) / 2
            if attn_avg >= threshold:
                unidirectional_avg_for_head.append((attn_avg, i, j))
    return unidirectional_avg_for_head


# Passing the pdb_str here is a workaround for streamlit caching,
# where I need the input to be hashable and unchanging.
# The ideal would be to pass in the structure directly, not parse it
# twice. If streamlit is upgraded to past 0.17 this can be fixed.
@st.cache(show_spinner=False)
def get_attention_pairs(
    pdb_str: str,
    layer: int,
    head: int,
    chain_ids: list[str] | None,
    threshold: float = 0.2,
    model_type: ModelType = ModelType.TAPE_BERT,
    top_n: int = 2,
    ec_numbers: list[list[ECNumber]] | None = None,
):
    """
    Note: All residue indexes returned are 0 indexed
    """
    structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
    if chain_ids:
        chains = [ch for ch in structure.get_chains() if ch.id in chain_ids]
    else:
        chains = list(structure.get_chains())

    # Chains are treated as lists of residues to make indexing easier
    # and to avoid trouble with residues in PDB files not having a consistent
    # start index
    chain_ids = [chain.id for chain in chains]
    chains = [[res for res in chain.get_residues()] for chain in chains]

    attention_pairs = []
    top_residues = []

    ec_tag_length = 4

    def is_tag(x):
        return x < ec_tag_length

    for i, chain in enumerate(chains):
        ec_number = ec_numbers[i] if ec_numbers else None
        ec_string = ".".join([ec.number for ec in ec_number]) if ec_number else ""
        sequence = res_to_1letter(chain)
        attention, _ = get_attention(sequence=sequence, model_type=model_type, ec_number=ec_string)
        attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)

        # Store the sum of attention into a residue (from the unidirectional attention)
        residue_attention = {}
        for attn_value, res_1, res_2 in attention_unidirectional:
            try:
                if not ec_number:
                    coord_1 = chain[res_1]["CA"].coord.tolist()
                    coord_2 = chain[res_2]["CA"].coord.tolist()
                else:
                    if is_tag(res_1):
                        coord_1 = ec_number[res_1].coordinate
                    else:
                        coord_1 = chain[res_1 - ec_tag_length]["CA"].coord.tolist()
                    if is_tag(res_2):
                        coord_2 = ec_number[res_2].coordinate
                    else:
                        coord_2 = chain[res_2 - ec_tag_length]["CA"].coord.tolist()
            except KeyError:
                continue

            attention_pairs.append((attn_value, coord_1, coord_2))
            if not ec_number:
                residue_attention[res_1] = residue_attention.get(res_1, 0) + attn_value
                residue_attention[res_2] = residue_attention.get(res_2, 0) + attn_value
            else:
                for res in [res_1, res_2]:
                    if not is_tag(res):
                        residue_attention[res - ec_tag_length] = (
                            residue_attention.get(res - ec_tag_length, 0) + attn_value
                        )

        if not ec_number:
            attention_into_res = attention[layer, head].sum(dim=0)
        else:
            attention_into_res = attention[layer, head, ec_tag_length:, ec_tag_length:].sum(dim=0)
        top_n_values, top_n_indexes = torch.topk(attention_into_res, top_n)

        for res, attn_sum in zip(top_n_indexes, top_n_values):
            fraction_of_total_attention = attn_sum.item() / len(sequence)
            top_residues.append((fraction_of_total_attention, chain_ids[i], res.item()))

    return attention_pairs, top_residues
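
# The block below is a minimal end-to-end sketch, not part of the Streamlit app flow.
# Assumptions: network access to files.rcsb.org, downloadable TAPE BERT weights, and
# "1AKE" as an arbitrary example PDB id.
if __name__ == "__main__":
    example_pdb_code = "1AKE"  # arbitrary example PDB id
    example_pdb_str = get_pdb_file(example_pdb_code).getvalue()
    example_structure = PDBParser().get_structure(example_pdb_code, StringIO(example_pdb_str))
    print("Chains:", get_chains(example_structure))

    pairs, top = get_attention_pairs(
        pdb_str=example_pdb_str,
        layer=0,
        head=0,
        chain_ids=None,  # use all chains
        threshold=0.2,
        model_type=ModelType.TAPE_BERT,
    )
    print(f"{len(pairs)} residue pairs above the attention threshold")
    print("Top residues per chain:", top)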