"""Glydentify (alpha): Gradio app predicting the glycosyltransferase family and
sugar donor of a single protein sequence.

The query is BLASTed against local GT-A and GT-B databases to decide which fold it
resembles; the matching fine-tuned ESM-2 sequence classifiers (``GTA_*.pth`` or
``GTB_*.pth``) are then used for the family and donor predictions.
"""
import io
import os
import pickle
import subprocess
import tempfile
import zipfile
from functools import lru_cache

import matplotlib
matplotlib.use('Agg')  # Use the non-interactive Agg backend (set before anything imports pyplot)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import gradio as gr
import Bio
from Bio import SeqIO
from Bio.Blast import NCBIXML
from IPython.display import clear_output
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from transformers import EsmForSequenceClassification, AdamW, AutoTokenizer

# Class-index -> label maps; indices must match the heads of the corresponding .pth models.
GTA_fam_dict = {
    0: "GT116", 1: "GT12", 2: "GT13", 3: "GT14", 4: "GT15", 5: "GT16", 6: "GT17",
    7: "GT2-clade1", 8: "GT2-clade2", 9: "GT2-clade3", 10: "GT2-clade4", 11: "GT2-clade5",
    12: "GT2-related", 13: "GT21", 14: "GT24", 15: "GT25", 16: "GT27", 17: "GT31",
    18: "GT32", 19: "GT34", 20: "GT40", 21: "GT43", 22: "GT45", 23: "GT49", 24: "GT54",
    25: "GT55", 26: "GT6", 27: "GT60", 28: "GT62", 29: "GT64", 30: "GT67", 31: "GT7",
    32: "GT75", 33: "GT77", 34: "GT78", 35: "GT8", 36: "GT81", 37: "GT82", 38: "GT84",
    39: "GT88", 40: "GT92",
}

GTA_don_dict = {
    0: "N-Acetyl Galactosamine", 1: "N-Acetyl Glucosamine", 2: "Arabinose", 3: "Galactose",
    4: "Galacturonic Acid", 5: "Glucose", 6: "Glucuronic Acid", 7: "Mannose",
    8: "Rhamnose", 9: "Xylose",
}

GTB_fam_dict = {
    0: "GT1", 1: "GT10", 2: "GT104", 3: "GT11", 4: "GT18", 5: "GT19", 6: "GT20",
    7: "GT23", 8: "GT28", 9: "GT3", 10: "GT30", 11: "GT35", 12: "GT37", 13: "GT38",
    14: "GT4", 15: "GT41", 16: "GT5", 17: "GT52", 18: "GT63", 19: "GT65", 20: "GT68",
    21: "GT70", 22: "GT72", 23: "GT80", 24: "GT9", 25: "GT90", 26: "GT99",
}

GTB_don_dict = {
    0: "Fucose", 1: "Galactose", 2: "N-Acetyl Galactosamine", 3: "Glucuronic Acid",
    4: "N-Acetyl Glucosamine", 5: "Glucose", 6: "Mannose", 7: "Other", 8: "Xylose",
}

# Base ESM-2 checkpoint shared by the tokenizer and every classifier.
# Larger alternative: facebook/esm2_t33_650M_UR50D
ESM_CHECKPOINT = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)

# Per-model head size and label map, keyed by checkpoint file name.
MODEL_CONFIG = {
    "GTA_fam.pth": {"num_labels": 41, "label_dict": GTA_fam_dict},
    "GTB_fam.pth": {"num_labels": 27, "label_dict": GTB_fam_dict},
    "GTA_don.pth": {"num_labels": 10, "label_dict": GTA_don_dict},
    "GTB_don.pth": {"num_labels": 9, "label_dict": GTB_don_dict},
}

# Local BLAST databases for each fold, and the e-value cut-off used both as the
# blastp -evalue option and as the "is this a GT-A/GT-B at all?" threshold.
BLAST_DB_PATHS = {"GTA": "blast_data/GTA/GTA.db", "GTB": "blast_data/GTB/GTB.db"}
BLAST_EVALUE_THRESHOLD = 1e-2

VALID_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")
SHORT_SEQUENCE_LENGTH = 100
SHORT_SEQUENCE_WARNING = (
    "**Warning:** The sequence is relatively short. "
    "Fragmentary and partial sequences may result in incorrect predictions."
)

# Static family annotations (GT-A fold only). Values are stored verbatim, some with
# trailing spaces; get_family_info strips them for display.
glycosyltransferase_db = {
    "GT40": {'CAZy Name': 'GT40', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT40.html'},
    "GT16": {'CAZy Name': 'GT16', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT16.html'},
    "GT27": {'CAZy Name': 'GT27', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT27.html'},
    "GT55": {'CAZy Name': 'GT55', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT55.html'},
    "GT25": {'CAZy Name': 'GT25', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT25.html'},
    "GT2": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html'},
    "GT84": {'CAZy Name': 'GT84', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT84.html'},
    "GT13": {'CAZy Name': 'GT13', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT13.html'},
    "GT67": {'CAZy Name': 'GT67', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT67.html'},
    "GT82": {'CAZy Name': 'GT82', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT82.html'},
    "GT24": {'CAZy Name': 'GT24', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT24.html'},
    "GT81": {'CAZy Name': 'GT81', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT81.html'},
    "GT49": {'CAZy Name': 'GT49', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT49.html'},
    "GT34": {'CAZy Name': 'GT34', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT34.html'},
    "GT45": {'CAZy Name': 'GT45', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT45.html'},
    "GT32": {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
    "GT88": {'CAZy Name': 'GT88', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT88.html'},
    "GT21": {'CAZy Name': 'GT21', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT21.html'},
    "GT54": {'CAZy Name': 'GT54', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT54.html'},
    "GT6": {'CAZy Name': 'GT6 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT6.html'},
    "GT7": {'CAZy Name': 'GT7 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT7.html'},
    "GT64": {'CAZy Name': 'GT64', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT64.html'},
    "GT78": {'CAZy Name': 'GT78', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT78.html'},
    "GT12": {'CAZy Name': 'GT12', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT12.html'},
    "GT31": {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
    "GT62": {'CAZy Name': 'GT62', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT62.html'},
    "GT8": {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html'},
    "GT15": {'CAZy Name': 'GT15', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT15.html'},
    "GT43": {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
    "GT60": {'CAZy Name': 'GT60', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT60.html'},
    "GT14": {'CAZy Name': 'GT14', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT14.html'},
    "GT17": {'CAZy Name': 'GT17', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT17.html'},
    "GT77": {'CAZy Name': 'GT77', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT77.html'},
    "GT75": {'CAZy Name': 'GT75', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT75.html'},
}


def parse_blast_output_for_best_evalue(output_file):
    """Return the smallest HSP e-value in a BLAST XML (-outfmt 5) result file.

    Args:
        output_file: path to a single-query BLAST XML report.

    Returns:
        The lowest ``expect`` value over all HSPs of all alignments, or ``None``
        when BLAST reported no hits.
    """
    with open(output_file) as result_handle:
        blast_record = NCBIXML.read(result_handle)

    evalues = [hsp.expect for alignment in blast_record.alignments for hsp in alignment.hsps]
    if not evalues:
        # No match at all; callers treat None as "infinitely bad".
        return None

    best_evalue = min(evalues)
    print(best_evalue)
    return best_evalue


def run_local_blast(sequence, database):
    """Run ``blastp`` for one protein sequence against a local database.

    Query and result files live in a private temporary directory, so concurrent
    requests do not overwrite each other and nothing is left behind if BLAST fails.

    Args:
        sequence: protein sequence (no FASTA header).
        database: path to the BLAST database, relative to the working directory.

    Returns:
        The best e-value (see ``parse_blast_output_for_best_evalue``) or ``None``.

    Raises:
        subprocess.CalledProcessError: if ``blastp`` exits with a non-zero status.
    """
    with tempfile.TemporaryDirectory() as workdir:
        query_file = os.path.join(workdir, "temp_query.fasta")
        output_file = os.path.join(workdir, "blast_results.xml")
        with open(query_file, "w") as handle:
            handle.write(">Query\n" + sequence + "\n")

        blast_cmd = [
            "blastp",
            "-query", query_file,
            "-db", database,
            "-out", output_file,
            "-outfmt", "5",  # Output format 5 is XML
            "-evalue", str(BLAST_EVALUE_THRESHOLD),
        ]
        subprocess.run(blast_cmd, check=True)
        return parse_blast_output_for_best_evalue(output_file)


def get_family_info(family_name):
    """Format the database entry of a GT family as Markdown.

    Sub-clade labels such as ``"GT2-clade3"`` fall back to their parent family
    entry (``"GT2"``). Empty fields are omitted, the "More Info" URL is rendered as
    a link, and each field is put on its own line.

    Returns:
        A Markdown string, or ``""`` if the family is not in the database.
    """
    family_info = glycosyltransferase_db.get(family_name)
    if family_info is None:
        family_info = glycosyltransferase_db.get(family_name.partition("-clade")[0], {})

    lines = []
    for key, value in family_info.items():
        value = value.strip()  # stored values sometimes carry trailing spaces
        if not value:
            continue
        if key == "More Info":
            lines.append("**{}:** [{}]({})".format(key, value, value))
        else:
            lines.append("**{}:** {}".format(key, value))
    # Two trailing spaces + newline is a Markdown hard line break.
    return "  \n".join(lines)


def fig_to_img(fig):
    """Converts a matplotlib figure to a PIL Image and returns it"""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    return img


def preprocess_protein_sequence(protein_fasta):
    """Validate a single-sequence FASTA/raw input and choose the models via BLAST.

    Returns:
        ``(protein_sequence, family_model_path, donor_model_path, None)`` on success,
        or ``(None, None, None, message)`` when the input is rejected. The returned
        sequence is upper-cased with all whitespace removed.
    """
    lines = [line.strip() for line in (protein_fasta or "").splitlines()]
    headers = [line for line in lines if line.startswith('>')]
    if len(headers) > 1:
        return None, None, None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."

    body = ''.join(line for line in lines if not line.startswith('>'))
    # Drop stray whitespace (e.g. '\r' from Windows line endings) and upper-case,
    # because the ESM tokenizer vocabulary is upper-case only.
    protein_sequence = ''.join(body.split()).upper()
    if not protein_sequence:
        return None, None, None, "No protein sequence found. Please enter a protein sequence."

    if any(char not in VALID_AMINO_ACIDS for char in protein_sequence):
        return None, None, None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids."

    print("Running Blast.")
    evalue_gta = run_local_blast(protein_sequence, BLAST_DB_PATHS["GTA"])
    evalue_gta = evalue_gta if evalue_gta is not None else float("inf")
    evalue_gtb = run_local_blast(protein_sequence, BLAST_DB_PATHS["GTB"])
    evalue_gtb = evalue_gtb if evalue_gtb is not None else float("inf")
    print("E-value GT-A:", evalue_gta, "E-value GT-B:", evalue_gtb)
    print("Blast finished running. Checking sequence against known data.")

    if evalue_gta > BLAST_EVALUE_THRESHOLD and evalue_gtb > BLAST_EVALUE_THRESHOLD:
        # Neither database matches well: the sequence is probably not a GT-A or GT-B.
        return None, None, None, "**Warning:** The sequence does not appear to be a GT-A or GT-B. Please ensure you are submitting a sequence from these families."

    # The fold with the better (smaller) e-value decides which model pair is used.
    fold = "GTA" if evalue_gta < evalue_gtb else "GTB"
    model_fam = f"{fold}_fam.pth"
    model_don = f"{fold}_don.pth"
    print("Selected model for family:", model_fam, "and donor:", model_don)
    return protein_sequence, model_fam, model_don, None


def _predict_probabilities(protein_sequence, model):
    """Tokenize one sequence, run the classifier and return softmax probabilities of shape (1, num_labels)."""
    encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = model(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"]).logits
    return F.softmax(logits, dim=1)


def _plot_top_k(probabilities, label_dict, k, title):
    """Draw a horizontal bar chart of the k most probable classes and return it as a PIL image."""
    k = min(k, probabilities.shape[-1])
    top_probs, top_labels = torch.topk(probabilities, k)
    # squeeze(0) (not squeeze()) keeps a list even when k == 1.
    top_probs = top_probs.squeeze(0).tolist()
    top_labels = top_labels.squeeze(0).tolist()
    decoded_labels = [label_dict.get(label, "Unknown") for label in top_labels]

    # Use an explicit Axes rather than pyplot's "current figure" state, which is
    # shared between concurrent requests.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        positions = np.arange(len(decoded_labels))
        ax.barh(positions, [prob * 100 for prob in top_probs], align='center', alpha=0.5)
        ax.set_yticks(positions)
        ax.set_yticklabels(decoded_labels)
        ax.set_xlabel('Probability (%)')
        ax.set_title(title)
        ax.set_xlim(0, 100)
        return fig_to_img(fig)
    finally:
        plt.close(fig)


def process_family_sequence(protein_sequence, modelfam, label_dict):
    """Predict the GT family of a sequence.

    Returns:
        ``(family_label, top5_plot_image, None, info_markdown)``. The info Markdown
        is prefixed with a warning when the sequence is shorter than
        ``SHORT_SEQUENCE_LENGTH`` residues.
    """
    probabilities = _predict_probabilities(protein_sequence, modelfam)
    predicted_index = probabilities.argmax(dim=1).item()  # single-sample batch
    decoded_label_fam = label_dict.get(predicted_index, "Unknown Label")
    family_info = get_family_info(decoded_label_fam)

    img = _plot_top_k(probabilities, label_dict, 5, 'Top 5 Family Class Probabilities')

    if len(protein_sequence) < SHORT_SEQUENCE_LENGTH:
        return decoded_label_fam, img, None, f"{SHORT_SEQUENCE_WARNING}\n\n{family_info}"
    return decoded_label_fam, img, None, family_info


def process_donor_sequence(protein_sequence, modeldon, label_dict):
    """Predict the sugar donor of a sequence.

    Always returns ``(donor_label, top3_plot_image, None)``. The short-sequence
    warning is reported by ``process_family_sequence`` only.
    """
    probabilities = _predict_probabilities(protein_sequence, modeldon)
    predicted_index = probabilities.argmax(dim=1).item()  # single-sample batch
    decoded_label_don = label_dict.get(predicted_index, "Unknown Label")

    img = _plot_top_k(probabilities, label_dict, 3, 'Top 3 Donor Class Probabilities')
    return decoded_label_don, img, None


@lru_cache(maxsize=len(MODEL_CONFIG))
def _load_classifier(model_path, num_labels):
    """Build an ESM classifier with ``num_labels`` outputs, load weights from ``model_path`` (CPU), cached per path."""
    model = EsmForSequenceClassification.from_pretrained(ESM_CHECKPOINT, num_labels=num_labels)
    # NOTE(review): strict=False silently ignores missing/unexpected keys in the
    # checkpoint; kept from the original, but a mismatched file would go unnoticed.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
    model.eval()
    return model.to('cpu')


def main_function_single(sequence):
    """Gradio handler: run family and donor prediction for one input sequence.

    Returns:
        ``(family_label, family_plot, info_markdown, donor_label, donor_plot)``.
        On rejected input, the message is returned in the info slot and the other
        slots are ``None``.
    """
    # Initial preprocessing including BLAST-based model selection
    protein_sequence, model_fam_path, model_don_path, error_msg = preprocess_protein_sequence(sequence)
    if error_msg:
        print(error_msg)
        return None, None, error_msg, None, None

    config_fam = MODEL_CONFIG[model_fam_path]
    config_don = MODEL_CONFIG[model_don_path]
    model_fam = _load_classifier(model_fam_path, config_fam["num_labels"])
    model_don = _load_classifier(model_don_path, config_don["num_labels"])

    family_label, family_img, _, family_info = process_family_sequence(protein_sequence, model_fam, config_fam["label_dict"])
    donor_label, donor_img, _ = process_donor_sequence(protein_sequence, model_don, config_don["label_dict"])
    return family_label, family_img, family_info, donor_label, donor_img


# NOTE(review): gr.inputs/gr.outputs and Row().style(...) are Gradio 3.x APIs
# (removed in Gradio 4). Component nesting below was reconstructed from a
# whitespace-collapsed source — confirm the intended layout.
prediction_imagefam = gr.outputs.Image(type='pil', label="Family prediction graph")
prediction_imagedonor = gr.outputs.Image(type='pil', label="Donor prediction graph")

with gr.Blocks() as app:
    gr.Markdown("# Glydentify (alpha v0.5)")
    with gr.Tab("Single Sequence Prediction"):
        with gr.Row().style(equal_height=True):
            with gr.Column():
                sequence = gr.inputs.Textbox(lines=16, placeholder='Enter Protein Sequence Here...', label="Protein Sequence")
            with gr.Column():
                with gr.Accordion("Example:"):
                    # Raw string: "\>" escapes the FASTA '>' so Markdown does not render a blockquote.
                    gr.Markdown(r"""
\>sp|Q9Y5Z6|B3GT1_HUMAN Beta-1,3-galactosyltransferase 1 OS=Homo sapiens OX=9606 GN=B3GALT1 PE=1 SV=1
MASKVSCLYVLTVVCWASALWYLSITRPTSSYTGSKPFSHLTVARKNFTFGNIRTRPINPHSFEFLINEPNKCEKNIPFLVILIST
THKEFDARQAIRETWGDENNFKGIKIATLFLLGKNADPVLNQMVEQESQIFHDIIVEDFIDSYHNLTLKTLMGMRWVATFCSK
AKYVMKTDSDIFVNMDNLIYKLLKPSTKPRRRYFTGYVINGGPIRDVRSKWYMPRDLYPDSNYPPFCSGTGYIFSADVAELIYK
TSLHTRLLHLEDVYVGLCLRKLGIHPFQNSGFNHWKMAYSLCRYRRVITVHQISPEEMHRIWNDMSSKKHLRC
""")
                family_prediction = gr.outputs.Textbox(label="Predicted family")
                donor_prediction = gr.outputs.Textbox(label="Predicted donor")
                info_markdown = gr.Markdown()

        # Predict button
        with gr.Row().style(equal_height=True):
            with gr.Column():
                predict_button = gr.Button("Predict")
                predict_button.click(
                    main_function_single,
                    inputs=[sequence],
                    outputs=[family_prediction, prediction_imagefam, info_markdown, donor_prediction, prediction_imagedonor],
                )

        # Family & Donor Section
        with gr.Row().style(equal_height=True):
            with gr.Column():
                with gr.Accordion("Family Prediction:"):
                    prediction_imagefam.render()
            with gr.Column():
                with gr.Accordion("Donor Prediction:"):
                    prediction_imagedonor.render()

app.launch(show_error=True)