yangheng committed on
Commit
a233a6e
·
verified ·
1 Parent(s): cfdab25

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
+ import torch
4
+
5
+
6
# Checkpoint identifier handed to both ``from_pretrained`` calls below.
model_name = "yangheng/PlantRNA-FM"

# Load the tokenizer and the masked-LM model once at import time; the
# prediction function reads these module-level objects on every call.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
9
+
10
+
11
def predict_rna(sequence):
    """Predict the token for every mask position in an RNA sequence.

    The sequence is tokenized, passed through the module-level masked
    language model, and the highest-scoring vocabulary entry is selected
    independently for each position whose id equals the tokenizer's mask
    token id.

    Args:
        sequence: RNA sequence text, expected to contain one or more mask
            tokens (e.g. ``<mask>``).

    Returns:
        The predicted tokens joined by single spaces. An empty string is
        returned when the input is empty/``None`` or contains no mask token.
    """
    # Nothing to tokenize, hence nothing to predict.
    if not sequence:
        return ""

    inputs = tokenizer(sequence, return_tensors="pt")

    # Find the positions of the mask token. torch.where yields one index
    # tensor per dimension; element [1] holds the positions along the
    # sequence axis (assumes input_ids is (batch, seq) -- TODO confirm).
    mask_token_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]

    # No mask in the input: the result can only be empty, so skip the
    # forward pass entirely.
    if mask_token_index.numel() == 0:
        return ""

    # Inference only; gradients are not needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Row 0 is the single sequence tokenized above; keep only mask positions.
    mask_token_logits = outputs.logits[0, mask_token_index, :]
    predicted_token_ids = torch.argmax(mask_token_logits, dim=-1)
    # Hand the tokenizer a plain list of ints rather than a tensor.
    predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids.tolist())

    return " ".join(predicted_tokens)
27
+
28
+
29
# Two-line input box; the placeholder shows an example sequence with <mask>.
input_text = gr.Textbox(
    lines=2,
    placeholder="Input RNA Sequence with <mask>, e.g., AAAGAGTCATATACGATATTGTCGACCGTGG<mask>AGAGAGAAGAATGTACGATTGGAGT",
)
# Text box that receives the string returned by the prediction function.
output_text = gr.Textbox()

# Connect the prediction function to the input and output text boxes.
app = gr.Interface(
    fn=predict_rna,
    inputs=input_text,
    outputs=output_text,
    title="Zero-shot PlantFM-RNA MNM Inference",
    description="Zero-shot PlantFM-RNA MNM Inference: Predicts only the <mask> tokens.",
)

# Launch the interface only when this file is run as a script.
if __name__ == "__main__":
    app.launch()