|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
import torch |
|
|
|
|
|
# Hugging Face Hub identifier of the pretrained model served by this app.
model_name = "yangheng/PlantRNA-FM"

# Load the tokenizer and the masked-language-model head once, at import time,
# so every request handled by the app reuses the same objects. The tuple is
# evaluated left to right, so the tokenizer is still loaded before the model.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForMaskedLM.from_pretrained(model_name),
)
|
|
|
|
|
def predict_rna(sequence):
    """Predict the most likely token for every mask position in ``sequence``.

    Args:
        sequence: RNA sequence text as entered in the UI. It should contain
            one or more occurrences of the tokenizer's mask token (the UI
            placeholder shows it as ``<mask>``).

    Returns:
        The top-scoring predicted token for each mask position, in
        left-to-right order, joined by single spaces. If the input contains
        no mask token, a short explanatory message is returned instead of an
        empty string.
    """
    inputs = tokenizer(sequence, return_tensors="pt")

    # The tokenizer returns a batch of one, so row 0 is the only row. Keep the
    # column indices, i.e. the sequence positions of the mask tokens.
    mask_token_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        # Nothing to predict: skip the forward pass and tell the user why,
        # rather than silently returning a blank result.
        return f"No {tokenizer.mask_token} token found in the input sequence."

    with torch.no_grad():
        outputs = model(**inputs)

    # Keep only the logits at the masked positions of the single batch item.
    mask_token_logits = outputs.logits[0, mask_token_index, :]
    predicted_token_ids = torch.argmax(mask_token_logits, dim=-1)
    # convert_ids_to_tokens is documented to take an int or a list of ints.
    predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids.tolist())

    return " ".join(predicted_tokens)
|
|
|
|
|
# Input box: the user enters an RNA sequence containing mask token(s).
input_text = gr.Textbox(
    lines=2,
    label="RNA sequence",
    placeholder="Input RNA Sequence with <mask>, e.g., AAAGAGTCATATACGATATTGTCGACCGTGG<mask>AGAGAGAAGAATGTACGATTGGAGT",
)

# Output box: the space-separated predicted tokens for the masked positions.
output_text = gr.Textbox(label="Predicted tokens")

app = gr.Interface(
    fn=predict_rna,
    inputs=input_text,
    outputs=output_text,
    title="Zero-shot PlantRNA-FM MNM Inference",
    # Gradio renders the description as Markdown/HTML, so a bare <mask> would
    # be parsed as an HTML tag and not displayed; a code span shows it literally.
    description="Zero-shot PlantRNA-FM MNM Inference: Predicts only the `<mask>` tokens.",
)

if __name__ == "__main__":
    app.launch()
|
|