import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, EsmForSequenceClassification
from transformers import set_seed
import torch
import torch.nn as nn
import warnings
from tqdm import tqdm
import gradio as gr

warnings.filterwarnings('ignore')
device = "cpu"
model_checkpoint1 = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint1)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert1 = EsmForSequenceClassification.from_pretrained(model_checkpoint1, num_labels=3000)  # 3000
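        # The 3000-dimensional ESM-2 "classification logits" are not used as class
        # scores here; they act as a fixed-size feature vector for the MLP head below.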
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)
        self.relu = nn.LeakyReLU()
        self.fc1 = nn.Linear(3000, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output_layer = nn.Linear(64, 2)
        self.dropout = nn.Dropout(0.3)  # 0.3

    def forward(self, x):
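        # x is the dict produced by the tokenizer (input_ids and attention_mask),
        # each of shape (batch_size, max_len).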
        with torch.no_grad():
            bert_output = self.bert1(input_ids=x['input_ids'],
                                     attention_mask=x['attention_mask'])
        output_feature = self.dropout(bert_output["logits"])
        output_feature = self.dropout(self.relu(self.bn1(self.fc1(output_feature))))
        output_feature = self.dropout(self.relu(self.bn2(self.fc2(output_feature))))
        output_feature = self.dropout(self.relu(self.bn3(self.fc3(output_feature))))
        output_feature = self.dropout(self.output_layer(output_feature))
        # return torch.sigmoid(output_feature),output_feature
        return torch.softmax(output_feature, dim=1)


def AMP(test_sequences, model):
    # Keep the AMP function as-is: it only processes the test_sequences passed in
    max_len = 18
    test_data = tokenizer(test_sequences, max_length=max_len, padding="max_length", truncation=True,
                          return_tensors='pt')
    model = model.to(device)
    model.eval()
    out_probability = []
    with torch.no_grad():
        predict = model(test_data)
        out_probability.extend(np.max(np.array(predict.cpu()), axis=1).tolist())
        test_argmax = np.argmax(predict.cpu(), axis=1).tolist()
    id2str = {0: "non-AMP", 1: "AMP"}
    return id2str[test_argmax[0]], out_probability[0]
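
# Usage sketch (assumes the fine-tuned weights below have been loaded into `model`):
#   label, prob = AMP("QGLFFLGAKLFYLLTLFL", model)  # -> ("AMP" or "non-AMP", softmax probability)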


def classify_sequence(sequence):
    # Check if the sequence is a valid amino acid sequence and has a length of at least 3
    valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")
    sequence = sequence.upper()

    if all(aa in valid_amino_acids for aa in sequence) and len(sequence) >= 3:
        result, probability = AMP(sequence, model)
        return "yes" if result == "AMP" else "no"
    else:
        return "Invalid Sequence"


# Load the model
model = MyModel()
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
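# "best_model.pth" is assumed to ship alongside this script and to contain the
# fine-tuned weights for MyModel; loading is mapped to CPU since the demo runs without a GPU.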

if __name__ == "__main__":
    title = """<h1 align="center">🔥AMP Sequence Detector</h1>"""
    css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
    theme = gr.themes.Soft(primary_hue="zinc", secondary_hue="blue", neutral_hue="green",
                           text_size=gr.themes.sizes.text_lg)
    with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
                      theme=theme) as demo:

        gr.Markdown("<h1>Diff-AMP</h1>")
        gr.HTML(title)


        gr.Markdown(
            "<p align='center' style='font-size: 20px;'>🔥Welcome to Antimicrobial Peptide Recognition Model. See our <a href='https://github.com/wrab12/diff-amp'>Project</a></p>")
        gr.HTML(
            '''<center>
  <a href="https://huggingface.co/spaces/jackrui/ampD?duplicate=true">
    <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space">
  </a>
</center>''')
        gr.HTML(
            '''<center>🌟Note: This is an antimicrobial peptide recognition model derived from Diff-AMP, which is a branch of a comprehensive system integrating generation, recognition, and optimization. In this recognition model, you can simply input a sequence, and it will predict whether it is an antimicrobial peptide. Due to limited website capacity, we can only perform simple predictions.
    If you require large-scale computations, please contact my email at wangrui66677@gmail.com. Feel free to reach out if you have any questions or inquiries.</center>''')


        # Example inputs and outputs
        examples = [
            ["QGLFFLGAKLFYLLTLFL"],
            ["FLGLLFHGVHHVGKWIHGLIHGHH"],
            ["GLMSTLKGAATNAAVTLLNKLQCKLTGTC"],
        ]

        # Create the Gradio interface and apply the styled theme and the examples above
        iface = gr.Interface(
            fn=classify_sequence,
            inputs="text",
            outputs="text",
            # title="AMP Sequence Detector",
            examples=examples,
        )
        gr.Markdown(
            "<p align='center'><img src='https://pic4.zhimg.com/v2-eb2a7c0e746e67d1768090eec74f6787_b.jpg'></p>")
        gr.Markdown("<p align='center' style='font-size: 20px;'>Related job links in the same series: </p>")
                    
        gr.Markdown("<p align='center'><a href='https://huggingface.co/spaces/jackrui/ampG'><img style='margin:-0.8em 0 2em 0;' src='https://shields.io/badge/Diff_AMP-Generator-blue' alt='Diff_AMP-Generator-blue'></a></p>"
                    "<p align='center'><a href='https://huggingface.co/spaces/jackrui/ampPP'><img style='margin:-0.8em 0 2em 0;' src='https://shields.io/badge/Diff_AMP-property_prediction-blue' alt='Diff_AMP-property_prediction-blue'></a></p>")
        gr.Markdown('''📝 **Citation**
If our work is useful for your research, please consider citing:
```
waiting...
```
📋 **License**

None

📧 **Contact**

If you have any questions, please feel free to reach out to me at <b>wangrui66677@gmail.com</b>.

🤗 **Find Me:**
<style type="text/css">
td {
    padding-right: 0px !important;
}
</style>
<table>
<tr>
    <td><a href="https://github.com/wrab12"><img style="margin:-0.8em 0 2em 0" src="https://img.shields.io/github/followers/wrab12?style=social" alt="Github Follow"></a></td>
    
</tr>
</table>
<center><img src='https://api.infinitescript.com/badgen/count?name=jackrui/ampD&ltext=Visitors&color=6dc9aa' alt='visitors'></center>
"""
''')

    demo.launch()