import functools

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Default DNA sequence; shown as the textbox placeholder and used when the
# user submits an empty input.
placeholder = 'ATGGAACTCATGAAGACGTTAGATCTTCACAAGAGGATATTTTCCGAATTTAGTGATGAACAATCAAGAGTGTCATACACTGCAAAAATCTATCAAGAACAAATAAAAGCGGCAAAAGGGAGGTTGCCTGATAGTAGTGTAAAGCAATTAGGTGTCTGGCAACTTCATGTTTTCCTCAAAAGATGTGAAAAAGCACCCAACCAGGACAATACGACATCAGGAATTCTGTAA'

model_names = [
    'plant-dnabert',
    'plant-dnagpt',
    'plant-nucleotide-transformer',
    'plant-dnagemma',
    'dnabert2',
    'nucleotide-transformer-v2-100m',
    'agront-1b',
]
tokenizer_type = "singlebase"
# Only the "plant-*" models get the tokenizer-type suffix; the resulting name
# becomes part of the repo id built in `inference`.
model_names = [f"{name}-{tokenizer_type}" if name.startswith("plant") else name for name in model_names]

# Task key -> class labels, in the order of the model's output logits.
task_map = {
    "promoter": ["Not promoter", "Core promoter"],
    "conservation": ["Not conserved", "Conserved"],
    "H3K27ac": ["Not H3K27ac", "H3K27ac"],
    "H3K27me3": ["Not H3K27me3", "H3K27me3"],
    "H3K4me3": ["Not H3K4me3", "H3K4me3"],
    "lncRNAs": ["Not lncRNA", "lncRNA"],
    "open_chromatin": ['Not open chromatin', 'Full open chromatin', 'Partial open chromatin'],
}
# A real list (not a dict_keys view) so it matches gr.Dropdown's `choices` type.
task_lists = list(task_map)

# How many (model, tokenizer) pairs to keep loaded between requests. Kept small
# because some checkpoints are presumably large (e.g. "agront-1b" suggests ~1B
# parameters) -- TODO confirm available memory before raising it.
_MODEL_CACHE_SIZE = 1


@functools.lru_cache(maxsize=_MODEL_CACHE_SIZE)
def _load_model_and_tokenizer(model_name):
    """Return ``(model, tokenizer)`` for repo id *model_name*, cached across calls.

    Repeated predictions with the same model reuse the loaded objects instead
    of calling ``from_pretrained`` again on every click.
    """
    # NOTE(review): with ignore_mismatched_sizes=True, transformers re-initializes
    # checkpoint weights whose shapes disagree with the config instead of raising,
    # which would silently produce meaningless predictions -- confirm it is needed.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, ignore_mismatched_sizes=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def inference(seq, model, task):
    """Classify a DNA sequence and return the probability of each task label.

    Args:
        seq: Input DNA sequence. All whitespace (e.g. line breaks from pasting a
            wrapped sequence) is removed; if nothing remains, ``placeholder`` is
            used and a Gradio warning is shown.
        model: Base model name, one of ``model_names``.
        task: Task key in ``task_map``.

    Returns:
        dict mapping each label of ``task_map[task]`` to its softmax
        probability, the format ``gr.Label`` displays.
    """
    # Strip whitespace so multi-line / padded input is not fed to the tokenizer.
    seq = "".join(seq.split()) if seq else ""
    if not seq:
        gr.Warning("No sequence provided, use the default sequence.")
        seq = placeholder

    # Load (or reuse) the fine-tuned checkpoint zhangtaolab/<model>-<task>.
    classifier, tokenizer = _load_model_and_tokenizer(f'zhangtaolab/{model}-{task}')

    # Inference; inputs longer than 512 tokens are truncated.
    inputs = tokenizer(seq, return_tensors='pt', padding=True, truncation=True, max_length=512)
    # Prediction needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        outputs = classifier(**inputs)
    probabilities = F.softmax(outputs.logits, dim=-1).tolist()[0]

    # Map probabilities to labels (logit order matches the label list order).
    labels = task_map[task]
    return dict(zip(labels, probabilities))


# Create Gradio interface
with gr.Blocks() as demo:
    gr.HTML(
        """

Prediction of long non-coding RNAs (lncRNAs) in plant with LLMs

"""
    )
    with gr.Row():
        drop1 = gr.Dropdown(choices=task_lists, label="Selected Task", interactive=False, value='lncRNAs')
        drop2 = gr.Dropdown(choices=model_names, label="Select Model", interactive=True, value=model_names[0])
    seq_input = gr.Textbox(label="Input Sequence", lines=6, placeholder=placeholder)
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        clear_btn = gr.Button("Clear")
    output = gr.Label(label="Predict result")

    predict_btn.click(inference, inputs=[seq_input, drop2, drop1], outputs=output)
    # Reset the input text and clear the result label.
    clear_btn.click(lambda: ("", None), inputs=[], outputs=[seq_input, output])

# Launch Gradio app
demo.launch()