File size: 2,676 Bytes
272d619
 
 
db94cbd
91f37c6
421c348
 
91f37c6
 
272d619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import subprocess
from functools import lru_cache

import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Model families offered in the "Select Model" dropdown.
model_names = [
    "plant-dnabert",
    "plant-nucleotide-transformer",
    "plant-dnagpt",
    "plant-dnagemma",
    "dnabert2",
    "nucleotide-transformer-v2-100m",
    "agront-1b",
]
# Tokenizer variant appended to the "plant*" model names below.
tokenizer_type = "BPE"
# Suffix every name starting with "plant" with "-<tokenizer_type>"
# (e.g. "plant-dnabert-BPE"); all other names are kept verbatim.
model_names = [
    f"{base}-{tokenizer_type}" if base.startswith("plant") else base
    for base in model_names
]

# Keep only the most recently used checkpoint in memory: switching models costs
# a reload (as before), but repeated requests to the same model reuse it instead
# of reloading the weights on every click.
@lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name):
    """Load and memoize the sequence-classification model and tokenizer for ``model_name``."""
    model = AutoModelForSequenceClassification.from_pretrained(model_name, ignore_mismatched_sizes=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def inference(seq, model, task):
    """Predict promoter strength for a DNA sequence with the selected model.

    Args:
        seq: Sequence text from the input textbox. All whitespace (e.g. line
            breaks from a pasted multi-line sequence) is removed. If nothing
            remains, a warning is shown and the module-level ``placeholder``
            sequence is used instead.
        model: Model name as offered in ``model_names`` (e.g. ``"dnabert2"``).
        task: Task suffix of the checkpoint
            (e.g. ``"promoter_strength_protoplast"``).

    Returns:
        The model's single output logit as a Python number.
    """
    # transformers already needs torch here (return_tensors='pt').
    import torch

    # `seq or ""` also tolerates None from the UI.
    seq = "".join((seq or "").split())
    if not seq:
        gr.Warning("No sequence provided, use the default sequence.")
        seq = placeholder

    # Checkpoints live on the Hub as zhangtaolab/<model>-<task>.
    checkpoint = f'zhangtaolab/{model}-{task}'
    classifier, tokenizer = _load_model_and_tokenizer(checkpoint)

    inputs = tokenizer(seq, return_tensors='pt', padding=True, truncation=True, max_length=1024)
    # Pure inference: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        outputs = classifier(**inputs)
    # .item() only works for a one-element tensor, i.e. a single-output head.
    return outputs.logits.item()

# Example promoter sequence: shown as the input textbox's placeholder text and
# used by `inference` when the user submits an empty sequence.
placeholder = 'TACTCTAATCGTATCAGCTGCACTTGCGTACAGGCTACCGGCGTCCTCAGCCACGTAAGAAAAGGCCCAATAAAGGCCCAACTACAACCAGCGGATATATATACTGGAGCCTGGCGAGATCACCCTAACCCCTCACACTCCCATCCAGCCGCCACCAGGTGCAGAGTGTT'
# Page styling passed to gr.Blocks(css=...): caps and centers the main container
# and sets a purple (#8e44ad) color for `.btn-primary` elements.
# NOTE(review): confirm `.btn-primary` matches a class Gradio actually renders for
# variant="primary" buttons; if it does not, that rule has no effect.
css = """
.gradio-container {
    max-width: 900px; 
    margin: auto; 
    padding: 20px;
} 
.btn-primary {
    background-color: #8e44ad; 
    border-color: #8e44ad;
}
"""    
# Build the Gradio interface

with gr.Blocks(css=css) as demo:
    # Page title.
    gr.HTML(
        """
        <h1 style="text-align: center;">Promoter strength in protoplast predicted by plant LLMs</h1>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Task selector: a single, non-editable choice. Its value is passed
            # to `inference` as the `task` argument.
            drop1 = gr.Dropdown(choices=['promoter_strength_protoplast'],
                            label="Selected Task",
                            interactive=False,
                            value="promoter_strength_protoplast")
        with gr.Column(scale=1):
            # Model selector (defaults to the first entry of model_names).
            # Its value is passed to `inference` as the `model` argument.
            drop2 = gr.Dropdown(choices=model_names,
                                label="Select Model",
                                interactive=True,
                                value=model_names[0])
    with gr.Row():
        with gr.Column(scale=1):
            # The example sequence is only displayed as placeholder text; an empty
            # submission is replaced by it inside `inference`.
            input_box = gr.Textbox(placeholder=placeholder, label="Enter promoter Sequence", lines=4)
        with gr.Column(scale=1):
            output_box = gr.Textbox(label="Output", lines=4)
    with gr.Row():
        submit_button = gr.Button("Submit", variant="primary")
        clear_button = gr.Button("Clear")
    # The order of `inputs` must match the signature inference(seq, model, task).
    submit_button.click(inference, inputs=[input_box,drop2,drop1], outputs=output_box)
    # Reset both the sequence and the result textboxes to empty strings.
    clear_button.click(lambda: ("", ""), inputs=None, outputs=[input_box, output_box])
# Launch the Gradio interface (a module-level call, so it runs whenever this file is executed).
demo.launch()