import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, EsmForSequenceClassification
from transformers import set_seed
import torch
import torch.nn as nn
import warnings
from tqdm import tqdm
import gradio as gr
warnings.filterwarnings('ignore')
device = "cpu"
model_checkpoint1 = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint1)
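# MyModel: a frozen ESM-2 (t12, 35M) sequence-classification backbone whose
# 3000-dimensional logits are passed through a small MLP head
# (3000 -> 256 -> 128 -> 64 -> 2) with BatchNorm, LeakyReLU and dropout,
# followed by a softmax over {non-AMP, AMP}.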
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert1 = EsmForSequenceClassification.from_pretrained(model_checkpoint1, num_labels=3000)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)
        self.relu = nn.LeakyReLU()
        self.fc1 = nn.Linear(3000, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output_layer = nn.Linear(64, 2)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        with torch.no_grad():
            bert_output = self.bert1(input_ids=x['input_ids'],
                                     attention_mask=x['attention_mask'])
        output_feature = self.dropout(bert_output["logits"])
        output_feature = self.dropout(self.relu(self.bn1(self.fc1(output_feature))))
        output_feature = self.dropout(self.relu(self.bn2(self.fc2(output_feature))))
        output_feature = self.dropout(self.relu(self.bn3(self.fc3(output_feature))))
        output_feature = self.dropout(self.output_layer(output_feature))
        return torch.softmax(output_feature, dim=1)
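# Inference helper: tokenizes a peptide (padded/truncated to 18 tokens) and runs
# one forward pass on the CPU.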
def AMP(test_sequences, model):
    # Only handles the test_sequences passed in; returns (label, probability).
    max_len = 18
    test_data = tokenizer(test_sequences, max_length=max_len, padding="max_length", truncation=True,
                          return_tensors='pt')
    model = model.to(device)
    model.eval()
    out_probability = []
    with torch.no_grad():
        predict = model(test_data)
        out_probability.extend(np.max(np.array(predict.cpu()), axis=1).tolist())
        test_argmax = np.argmax(predict.cpu(), axis=1).tolist()
    id2str = {0: "non-AMP", 1: "AMP"}
    return id2str[test_argmax[0]], out_probability[0]
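# Gradio callback: validates the input string, then maps the model's prediction
# to "yes"/"no" (or flags an invalid sequence).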
def classify_sequence(sequence):
    # Check if the sequence is a valid amino acid sequence and has a length of at least 3
    valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")
    sequence = sequence.upper()
    if all(aa in valid_amino_acids for aa in sequence) and len(sequence) >= 3:
        result, probability = AMP(sequence, model)
        return "yes" if result == "AMP" else "no"
    else:
        return "Invalid Sequence"
# Load the trained model weights
model = MyModel()
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
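# Note: "best_model.pth" is assumed to sit in the working directory; adjust the
# path if the checkpoint lives elsewhere. Once loaded, the classifier can also be
# called directly, e.g. classify_sequence("GLMSTLKGAATNAAVTLLNKLQCKLTGTC").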
if __name__ == "__main__":
    title = """<h1 align="center">🔥AMP Sequence Detector</h1>"""
    css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
    theme = gr.themes.Soft(primary_hue="zinc", secondary_hue="blue", neutral_hue="green",
                           text_size=gr.themes.sizes.text_lg)
    with gr.Blocks(css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
                   theme=theme) as demo:
        gr.Markdown("<h1>Diff-AMP</h1>")
        gr.HTML(title)
        gr.Markdown(
            "<p align='center' style='font-size: 20px;'>🔥Welcome to the Antimicrobial Peptide Recognition Model. See our <a href='https://github.com/wrab12/diff-amp'>Project</a></p>")
        gr.HTML(
            '''<center>
            <a href="https://huggingface.co/spaces/jackrui/ampD?duplicate=true">
            <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space">
            </a>
            </center>''')
        gr.HTML(
            '''<center>🌟Note: This is an antimicrobial peptide recognition model derived from Diff-AMP, which is a branch of a comprehensive system integrating generation, recognition, and optimization. Simply input a sequence and the model will predict whether it is an antimicrobial peptide. Due to limited website capacity, we can only perform simple predictions.
            If you require large-scale computations, please contact me at wangrui66677@gmail.com. Feel free to reach out with any questions or inquiries.</center>''')
        # Example inputs
        examples = [
            ["QGLFFLGAKLFYLLTLFL"],
            ["FLGLLFHGVHHVGKWIHGLIHGHH"],
            ["GLMSTLKGAATNAAVTLLNKLQCKLTGTC"]
        ]
        # Create the Gradio interface with the themed styling and examples applied
        iface = gr.Interface(
            fn=classify_sequence,
            inputs="text",
            outputs="text",
            examples=examples
        )
        gr.Markdown(
            "<p align='center'><img src='https://pic4.zhimg.com/v2-eb2a7c0e746e67d1768090eec74f6787_b.jpg'></p>")
        gr.Markdown("<p align='center' style='font-size: 20px;'>Related links in the same series: </p>")
        gr.Markdown("<p align='center'><a href='https://huggingface.co/spaces/jackrui/ampG'><img style='margin:-0.8em 0 2em 0;' src='https://shields.io/badge/Diff_AMP-Generator-blue' alt='Diff_AMP-Generator-blue'></a></p>"
                    "<p align='center'><a href='https://huggingface.co/spaces/jackrui/ampPP'><img style='margin:-0.8em 0 2em 0;' src='https://shields.io/badge/Diff_AMP-property_prediction-blue' alt='Diff_AMP-property_prediction-blue'></a></p>")
        gr.Markdown('''📝 **Citation**
If our work is useful for your research, please consider citing:
```
waiting...
```
📋 **License**
None
📧 **Contact**
If you have any questions, please feel free to reach out to me at <b>wangrui66677@gmail.com</b>.
🤗 **Find Me:**
<style type="text/css">
td {
    padding-right: 0px !important;
}
</style>
<table>
<tr>
<td><a href="https://github.com/wrab12"><img style="margin:-0.8em 0 2em 0" src="https://img.shields.io/github/followers/wrab12?style=social" alt="Github Follow"></a></td>
</tr>
</table>
<center><img src='https://api.infinitescript.com/badgen/count?name=jackrui/ampD<ext=Visitors&color=6dc9aa' alt='visitors'></center>
''')
    demo.launch()