Glydentify / app.py
Aarya Venkat
Update -- need to add new model
d1ca73b
import pandas as pd
from IPython.display import clear_output
import torch
from transformers import EsmForSequenceClassification, AdamW, AutoTokenizer
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib
matplotlib.use('Agg') # Use the non-interactive Agg backend
import matplotlib.pyplot as plt
import pickle
import torch.nn.functional as F
import gradio as gr
import io
from PIL import Image
import Bio
from Bio import SeqIO
from Bio.Blast import NCBIXML
import subprocess
import zipfile
import os
# Class index -> family label for the GT-A family classifier (GTA_fam.pth,
# 41 outputs in model_config). The index of each label is its list position.
GTA_fam_dict = dict(enumerate([
    "GT116", "GT12", "GT13", "GT14", "GT15", "GT16", "GT17",
    "GT2-clade1", "GT2-clade2", "GT2-clade3", "GT2-clade4", "GT2-clade5",
    "GT2-related",
    "GT21", "GT24", "GT25", "GT27", "GT31", "GT32", "GT34", "GT40",
    "GT43", "GT45", "GT49", "GT54", "GT55", "GT6", "GT60", "GT62",
    "GT64", "GT67", "GT7", "GT75", "GT77", "GT78", "GT8", "GT81",
    "GT82", "GT84", "GT88", "GT92",
]))
# Class index -> donor sugar name for the GT-A donor classifier (GTA_don.pth,
# 10 outputs in model_config). The index of each name is its list position.
GTA_don_dict = dict(enumerate([
    "N-Acetyl Galactosamine",
    "N-Acetyl Glucosamine",
    "Arabinose",
    "Galactose",
    "Galacturonic Acid",
    "Glucose",
    "Glucuronic Acid",
    "Mannose",
    "Rhamnose",
    "Xylose",
]))
# Class index -> family label for the GT-B family classifier (GTB_fam.pth,
# 27 outputs in model_config). The index of each label is its list position.
GTB_fam_dict = dict(enumerate([
    "GT1", "GT10", "GT104", "GT11", "GT18", "GT19", "GT20", "GT23",
    "GT28", "GT3", "GT30", "GT35", "GT37", "GT38", "GT4", "GT41",
    "GT5", "GT52", "GT63", "GT65", "GT68", "GT70", "GT72", "GT80",
    "GT9", "GT90", "GT99",
]))
# Class index -> donor sugar name for the GT-B donor classifier (GTB_don.pth,
# 9 outputs in model_config). The index of each name is its list position.
GTB_don_dict = dict(enumerate([
    "Fucose",
    "Galactose",
    "N-Acetyl Galactosamine",
    "Glucuronic Acid",
    "N-Acetyl Glucosamine",
    "Glucose",
    "Mannose",
    "Other",
    "Xylose",
]))
# Shared tokenizer for both classifiers. It must match the checkpoint the
# classification models are built from in main_function_single.
# (A larger checkpoint mentioned as an alternative: facebook/esm2_t33_650M_UR50D.)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="facebook/esm2_t12_35M_UR50D",
)
# Static reference metadata per CAZy family, keyed by family name. It is looked
# up by get_family_info(), which renders each field as "**Field:** value"
# markdown for the info panel.
# Every entry has Fold 'A'; 'Mechanism' is 'Inverting' or 'Retaining';
# 'Clade' is a number-as-string or 'N/A'.
# NOTE(review): no entry exists for any GT-B label (GTB_fam_dict), nor for the
# GT-A labels GT116, GT92, "GT2-clade1".."GT2-clade5" and "GT2-related" (this
# table only has a plain "GT2"), so get_family_info() returns "" for those.
# NOTE(review): several values carry a trailing space (e.g. 'GT2 ', '6 ');
# this looks like a column-alignment artifact. It is harmless in markdown, but
# confirm before comparing these strings programmatically.
glycosyltransferase_db = {
    "GT40" : {'CAZy Name': 'GT40', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT40.html'},
    "GT16" : {'CAZy Name': 'GT16', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT16.html'},
    "GT27" : {'CAZy Name': 'GT27', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT27.html'},
    "GT55" : {'CAZy Name': 'GT55', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT55.html'},
    "GT25" : {'CAZy Name': 'GT25', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT25.html'},
    "GT2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT84" : {'CAZy Name': 'GT84', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT84.html'},
    "GT13" : {'CAZy Name': 'GT13', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT13.html'},
    "GT67" : {'CAZy Name': 'GT67', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT67.html'},
    "GT82" : {'CAZy Name': 'GT82', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT82.html'},
    "GT24" : {'CAZy Name': 'GT24', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT24.html'},
    "GT81" : {'CAZy Name': 'GT81', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT81.html'},
    "GT49" : {'CAZy Name': 'GT49', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT49.html'},
    "GT34" : {'CAZy Name': 'GT34', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT34.html'},
    "GT45" : {'CAZy Name': 'GT45', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT45.html'},
    "GT32" : {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
    "GT88" : {'CAZy Name': 'GT88', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT88.html'},
    "GT21" : {'CAZy Name': 'GT21', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT21.html'},
    "GT54" : {'CAZy Name': 'GT54', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT54.html'},
    "GT6" : {'CAZy Name': 'GT6 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT6.html' },
    "GT7" : {'CAZy Name': 'GT7 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT7.html' },
    "GT64" : {'CAZy Name': 'GT64', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT64.html'},
    "GT78" : {'CAZy Name': 'GT78', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT78.html'},
    "GT12" : {'CAZy Name': 'GT12', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT12.html'},
    "GT31" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
    "GT62" : {'CAZy Name': 'GT62', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT62.html'},
    "GT8" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
    "GT15" : {'CAZy Name': 'GT15', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT15.html'},
    "GT43" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
    "GT60" : {'CAZy Name': 'GT60', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT60.html'},
    "GT14" : {'CAZy Name': 'GT14', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT14.html'},
    "GT17" : {'CAZy Name': 'GT17', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT17.html'},
    "GT77" : {'CAZy Name': 'GT77', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT77.html'},
    "GT75" : {'CAZy Name': 'GT75', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT75.html'},
}
def parse_blast_output_for_best_evalue(output_file):
    """Return the E-value of the top hit in a single-query BLAST XML report.

    The first HSP of the first alignment in the report is taken as the best
    hit. The value found is also printed.

    Args:
        output_file: Path to a BLAST XML report (``-outfmt 5``) with one query.

    Returns:
        That HSP's ``expect`` value, or ``None`` when the report lists no
        alignments at all.
    """
    with open(output_file) as handle:
        record = NCBIXML.read(handle)

    if not record.alignments:
        # No hit was reported; callers decide how to treat "no match".
        return None

    top_evalue = record.alignments[0].hsps[0].expect
    print(top_evalue)
    return top_evalue
def run_local_blast(sequence, database):
    """Run ``blastp`` for one protein sequence and return the best E-value.

    The query FASTA and the XML report live in a private temporary directory.
    Overlapping calls (e.g. two web requests) therefore cannot overwrite each
    other's files, and the directory is removed even if BLAST or parsing
    raises. The original implementation used fixed file names in the working
    directory.

    Args:
        sequence: Protein sequence (one-letter codes, no FASTA header).
        database: BLAST protein database passed to ``-db``.

    Returns:
        The best E-value from parse_blast_output_for_best_evalue(), or ``None``
        when the report has no alignments (only hits with E-value <= 1e-2 are
        reported, see ``-evalue`` below).

    Raises:
        subprocess.CalledProcessError: if ``blastp`` exits with a non-zero status.
        FileNotFoundError: if the ``blastp`` executable cannot be found.
    """
    import tempfile  # only needed here

    with tempfile.TemporaryDirectory(prefix="glydentify_blast_") as workdir:
        query_file = os.path.join(workdir, "query.fasta")
        output_file = os.path.join(workdir, "blast_results.xml")

        with open(query_file, "w") as handle:
            handle.write(">Query\n" + sequence + "\n")

        blast_cmd = [
            "blastp",
            "-query", query_file,
            "-db", database,
            "-out", output_file,
            "-outfmt", "5",      # 5 = XML, read back with NCBIXML
            "-evalue", "1e-2",   # E-value threshold for reported hits
        ]
        # argv list with the default shell=False: no shell interprets the inputs.
        subprocess.run(blast_cmd, check=True)

        # Parse before the with-block exits and deletes the directory.
        return parse_blast_output_for_best_evalue(output_file)
def _format_info_label(key):
    """Turn a metadata key into a display label for get_family_info().

    Keys that already contain capitals (e.g. "CAZy Name", "More Info") are
    shown verbatim, because str.title() would turn "CAZy" into "Cazy".
    All-lowercase snake_case keys become Title Case with spaces.
    """
    if key != key.lower():
        return key
    return key.replace("_", " ").title()


def get_family_info(family_name, db=None):
    """Render the reference metadata of one family as a markdown string.

    Each field is emitted as ``**Label:** value `` and the fields are
    concatenated. A "More Info"/``more_info`` field is rendered as markdown
    link(s); its value may be a single URL string or an iterable of URLs.

    Args:
        family_name: Family label as produced by the classifier, e.g. "GT8".
        db: Mapping of family name -> {field: value}. Defaults to the
            module-level ``glycosyltransferase_db``.

    Returns:
        The markdown string, or ``""`` when ``family_name`` has no entry.
    """
    if db is None:
        db = glycosyltransferase_db
    family_info = db.get(family_name, {})

    parts = []
    for key, value in family_info.items():
        label = _format_info_label(key)
        # Match both "more_info" and the "More Info" key the table really uses.
        # The old check only matched "more_info" and then iterated the URL
        # string character by character.
        if key.strip().lower().replace(" ", "_") == "more_info":
            links = [value] if isinstance(value, str) else list(value)
            rendered = "".join("[{0}]({0}) ".format(link) for link in links if link)
            parts.append("**{}:** {}".format(label, rendered))
        else:
            parts.append("**{}:** {} ".format(label, value))
    return "".join(parts)
def fig_to_img(fig):
    """Render a matplotlib figure to PNG in memory and return it as a PIL Image.

    The figure is saved with ``bbox_inches='tight'`` so surrounding whitespace
    is trimmed; nothing is written to disk.
    """
    scratch = io.BytesIO()
    fig.savefig(scratch, bbox_inches='tight', format='png')
    png_bytes = scratch.getvalue()
    return Image.open(io.BytesIO(png_bytes))
def preprocess_protein_sequence(protein_fasta):
    """Validate one protein (FASTA or bare sequence) and pick GT-A or GT-B models.

    Steps:
      1. Reject input with more than one FASTA header line.
      2. Join all non-header lines into one upper-case sequence with every
         whitespace character removed (CRLF endings and blank lines included).
      3. Reject empty input and any character outside the 20 standard amino acids.
      4. BLAST against the local GT-A and GT-B databases and select the model
         pair of the database with the lower best E-value.

    Args:
        protein_fasta: Text entered by the user.

    Returns:
        ``(protein_sequence, model_fam_path, model_don_path, None)`` on success,
        or ``(None, None, None, message)`` where ``message`` is a user-facing
        markdown string explaining why the input was rejected.
    """
    # strip() each line so '\r' from CRLF input and stray spaces are not
    # counted as residues, and an indented ">" line still counts as a header.
    lines = [line.strip() for line in (protein_fasta or "").splitlines()]
    headers = [line for line in lines if line.startswith('>')]
    if len(headers) > 1:
        return None, None, None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."

    # Join sequence lines, also dropping whitespace inside a line. Upper-case
    # because the check below accepts lower-case letters too.
    # NOTE(review): the ESM-2 tokenizer vocabulary is assumed to contain only
    # upper-case residue letters, so lower-case input would not be recognised.
    protein_sequence = ''.join(
        ''.join(line.split()) for line in lines if not line.startswith('>')
    ).upper()
    if not protein_sequence:
        # Previously an empty query passed validation and was handed to blastp.
        return None, None, None, "No protein sequence found. Please enter an amino acid sequence."

    valid_characters = set("ACDEFGHIKLMNPQRSTVWY")
    if not set(protein_sequence) <= valid_characters:
        return None, None, None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids."

    print("Running Blast.")
    gta_db_path = "blast_data/GTA/GTA.db"
    gtb_db_path = "blast_data/GTB/GTB.db"
    no_hit_evalue = 1e+100  # stands in for "no hit" so the comparisons below work
    evalue_gta = run_local_blast(protein_sequence, gta_db_path)
    evalue_gta = evalue_gta if evalue_gta is not None else no_hit_evalue
    evalue_gtb = run_local_blast(protein_sequence, gtb_db_path)
    evalue_gtb = evalue_gtb if evalue_gtb is not None else no_hit_evalue
    print("E-value GT-A:", evalue_gta, "E-value GT-B:", evalue_gtb)
    print("Blast finished running. Checking sequence against known data.")

    # The database with the smaller best E-value decides the model pair
    # (a tie selects GT-B).
    use_gta = evalue_gta < evalue_gtb
    model_fam = "GTA_fam.pth" if use_gta else "GTB_fam.pth"
    model_don = "GTA_don.pth" if use_gta else "GTB_don.pth"
    print("Selected model for family:", model_fam, "and donor:", model_don)

    # Same threshold as the -evalue given to blastp in run_local_blast: no
    # credible hit in either database means the input is probably not a GT-A/GT-B.
    evalue_cutoff = 1e-2
    if evalue_gta > evalue_cutoff and evalue_gtb > evalue_cutoff:
        return None, None, None, "**Warning:** The sequence does not appear to be a GT-A or GT-B. Please ensure you are submitting a sequence from these families."
    return protein_sequence, model_fam, model_don, None
def process_family_sequence(protein_sequence, modelfam, label_dict, top_k=5):
    """Predict the family of a sequence and plot the top-k class probabilities.

    Args:
        protein_sequence: Validated protein sequence.
        modelfam: Sequence-classification model whose output indices are the
            keys of ``label_dict``.
        label_dict: Mapping class index -> family label.
        top_k: Number of most probable classes to plot. It is capped at the
            number of model outputs. Defaults to 5.

    Returns:
        ``(family_label, plot_image, None, info_markdown)``. ``info_markdown``
        is get_family_info() for the predicted label; for sequences shorter
        than 100 residues it is prefixed with a warning.
    """
    # Inputs longer than 512 tokens are truncated by the tokenizer.
    encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = modelfam(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"]).logits
    # One input sequence -> row 0 is the probability vector over classes.
    probabilities = F.softmax(logits, dim=1)[0]

    predicted_index = int(torch.argmax(probabilities).item())
    decoded_label_fam = label_dict.get(predicted_index, "Unknown Label")
    family_info = get_family_info(decoded_label_fam)

    # Cap k so topk() cannot fail for models with fewer outputs. tolist() on
    # a 1-D tensor always yields a list (squeeze() would yield a bare int for k=1).
    k = min(top_k, probabilities.shape[0])
    top_probs, top_indices = torch.topk(probabilities, k)
    top_labels = [label_dict.get(index, "Unknown") for index in top_indices.tolist()]

    # Draw on an explicit Axes instead of pyplot's implicit "current axes".
    fig, ax = plt.subplots(figsize=(10, 5))
    y_pos = np.arange(len(top_labels))
    ax.barh(y_pos, [p * 100 for p in top_probs.tolist()], align='center', alpha=0.5)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(top_labels)
    ax.set_xlabel('Probability (%)')
    ax.set_title('Top {} Family Class Probabilities'.format(k))
    ax.set_xlim(0, 100)
    img = fig_to_img(fig)
    plt.close(fig)  # release the figure registered with pyplot

    if len(protein_sequence) < 100:
        # f-string: the original lacked the prefix and showed "{family_info}" literally.
        return decoded_label_fam, img, None, f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
    return decoded_label_fam, img, None, family_info
def process_donor_sequence(protein_sequence, modeldon, label_dict, top_k=3):
    """Predict the donor sugar of a sequence and plot the top-k class probabilities.

    Args:
        protein_sequence: Validated protein sequence.
        modeldon: Sequence-classification model whose output indices are the
            keys of ``label_dict``.
        label_dict: Mapping class index -> donor name.
        top_k: Number of most probable classes to plot. It is capped at the
            number of model outputs. Defaults to 3.

    Returns:
        ``(donor_label, plot_image, warning)``, always exactly three items so
        callers can unpack them directly. ``warning`` is a markdown notice for
        sequences shorter than 100 residues, otherwise ``None``. (The original
        returned four items for short sequences and referenced an undefined
        ``family_info``.)
    """
    # Inputs longer than 512 tokens are truncated by the tokenizer.
    encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = modeldon(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"]).logits
    # One input sequence -> row 0 is the probability vector over classes.
    probabilities = F.softmax(logits, dim=1)[0]

    predicted_index = int(torch.argmax(probabilities).item())
    decoded_label_don = label_dict.get(predicted_index, "Unknown Label")

    # Cap k so topk() cannot fail for models with fewer outputs.
    k = min(top_k, probabilities.shape[0])
    top_probs, top_indices = torch.topk(probabilities, k)
    top_labels = [label_dict.get(index, "Unknown") for index in top_indices.tolist()]

    # Draw on an explicit Axes instead of pyplot's implicit "current axes".
    fig, ax = plt.subplots(figsize=(10, 5))
    y_pos = np.arange(len(top_labels))
    ax.barh(y_pos, [p * 100 for p in top_probs.tolist()], align='center', alpha=0.5)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(top_labels)
    ax.set_xlabel('Probability (%)')
    ax.set_title('Top {} Donor Class Probabilities'.format(k))
    ax.set_xlim(0, 100)
    img = fig_to_img(fig)
    plt.close(fig)  # release the figure registered with pyplot

    warning = None
    if len(protein_sequence) < 100:
        warning = ("**Warning:** The sequence is relatively short. Fragmentary and "
                   "partial sequences may result in incorrect predictions.")
    return decoded_label_don, img, warning
# Classifiers already built, keyed by (checkpoint path, num_labels). Each
# checkpoint is then loaded once per process instead of on every Predict
# click. model_config below names at most four checkpoints, so this stays small.
_MODEL_CACHE = {}


def _load_classifier(model_path, num_labels):
    """Return an eval-mode CPU ESM classifier with weights from ``model_path`` (cached)."""
    cache_key = (model_path, num_labels)
    model = _MODEL_CACHE.get(cache_key)
    if model is not None:
        return model

    model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=num_labels)
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # strict=False is kept so key differences do not abort loading. Mismatches
    # are reported, because missing keys silently keep their initial values.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"Warning: {model_path} has no weights for: {incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        print(f"Warning: {model_path} has unused keys: {incompatible.unexpected_keys}")
    model.eval()
    model.to('cpu')
    _MODEL_CACHE[cache_key] = model
    return model


def main_function_single(sequence):
    """Gradio handler for the Predict button: validate, classify, and plot.

    Args:
        sequence: Raw text from the protein-sequence textbox.

    Returns:
        ``(family_label, family_plot, info_markdown, donor_label, donor_plot)``,
        matching the button's five outputs. When the input is rejected, the
        markdown slot holds the error message and every other slot is ``None``.
    """
    # Validation plus BLAST-based choice between the GT-A and GT-B models.
    protein_sequence, model_fam_path, model_don_path, error_msg = preprocess_protein_sequence(sequence)
    if error_msg:
        print(error_msg)
        return None, None, error_msg, None, None

    model_config = {
        "GTA_fam.pth": {"num_labels": 41, "label_dict": GTA_fam_dict},
        "GTB_fam.pth": {"num_labels": 27, "label_dict": GTB_fam_dict},
        "GTA_don.pth": {"num_labels": 10, "label_dict": GTA_don_dict},
        "GTB_don.pth": {"num_labels": 9, "label_dict": GTB_don_dict},
    }
    config_fam = model_config[model_fam_path]
    config_don = model_config[model_don_path]
    model_fam = _load_classifier(model_fam_path, config_fam["num_labels"])
    model_don = _load_classifier(model_don_path, config_don["num_labels"])

    family_label, family_img, _, family_info = process_family_sequence(protein_sequence, model_fam, config_fam["label_dict"])
    # Only the label and plot are displayed. Indexing (rather than unpacking
    # exactly three items) tolerates process_donor_sequence returning an
    # extra trailing item.
    donor_result = process_donor_sequence(protein_sequence, model_don, config_don["label_dict"])
    donor_label, donor_img = donor_result[0], donor_result[1]
    return family_label, family_img, family_info, donor_label, donor_img
# The image outputs are created before the layout and placed later with
# .render(), so the Predict button's click handler can list them as outputs
# before they appear on the page. gr.Image / gr.Textbox replace the
# deprecated gr.outputs.* / gr.inputs.* wrappers.
prediction_imagefam = gr.Image(type='pil', label="Family prediction graph")
prediction_imagedonor = gr.Image(type='pil', label="Donor prediction graph")

with gr.Blocks() as app:
    gr.Markdown("# Glydentify (alpha v0.5)")
    with gr.Tab("Single Sequence Prediction"):
        with gr.Row().style(equal_height=True):
            with gr.Column():
                sequence = gr.Textbox(lines=16, placeholder='Enter Protein Sequence Here...', label="Protein Sequence")
            with gr.Column():
                with gr.Accordion("Example:"):
                    # Raw string: in a normal string "\>" is an invalid escape
                    # sequence (SyntaxWarning on Python 3.12+). The text the
                    # Markdown component receives is unchanged.
                    gr.Markdown(r"""
\>sp|Q9Y5Z6|B3GT1_HUMAN Beta-1,3-galactosyltransferase 1 OS=Homo sapiens OX=9606 GN=B3GALT1 PE=1 SV=1
MASKVSCLYVLTVVCWASALWYLSITRPTSSYTGSKPFSHLTVARKNFTFGNIRTRPINPHSFEFLINEPNKCEKNIPFLVILIST
THKEFDARQAIRETWGDENNFKGIKIATLFLLGKNADPVLNQMVEQESQIFHDIIVEDFIDSYHNLTLKTLMGMRWVATFCSK
AKYVMKTDSDIFVNMDNLIYKLLKPSTKPRRRYFTGYVINGGPIRDVRSKWYMPRDLYPDSNYPPFCSGTGYIFSADVAELIYK
TSLHTRLLHLEDVYVGLCLRKLGIHPFQNSGFNHWKMAYSLCRYRRVITVHQISPEEMHRIWNDMSSKKHLRC
""")
                family_prediction = gr.Textbox(label="Predicted family")
                donor_prediction = gr.Textbox(label="Predicted donor")
                info_markdown = gr.Markdown()
        # Predict button
        with gr.Row().style(equal_height=True):
            with gr.Column():
                predict_button = gr.Button("Predict")
                predict_button.click(main_function_single, inputs=[sequence],
                                     outputs=[family_prediction, prediction_imagefam, info_markdown,
                                              donor_prediction, prediction_imagedonor])
        # Family & Donor plots
        with gr.Row().style(equal_height=True):
            with gr.Column():
                with gr.Accordion("Family Prediction:"):
                    prediction_imagefam.render()
            with gr.Column():
                with gr.Accordion("Donor Prediction:"):
                    prediction_imagedonor.render()
app.launch(show_error=True)