import re
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Automatically detect GPU or use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Default model path
model_tokenizer_path = "zehui127/Omni-DNA-Multitask"

# Load tokenizer and model with trusted remote code
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_code=True).to(device)
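# Untested sketch: on memory-constrained GPUs the model could instead be
# loaded in half precision, e.g.
#   model = AutoModelForCausalLM.from_pretrained(
#       model_tokenizer_path, trust_remote_code=True, torch_dtype=torch.float16
#   ).to(device)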

# List of available tasks
# Histone-modification prediction tasks supported by the multitask model
tasks = ['H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac', 'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3']

# Maps the extracted label to a human-readable phrase
mapping = {
    '1': 'It is a',
    '0': 'It is not a',
    'No valid prediction': 'Cannot be determined whether or not it is a',
}

def preprocess_response(response, mask_token="[MASK]"):
    """Return the text after the mask token, with any leading whitespace and
    nucleotide characters (A/T/G/C) stripped. Currently unused; generate()
    relies on extract_label() below."""
    if mask_token in response:
        response = response.split(mask_token, 1)[1]
    response = re.sub(r'^[\sATGC]+', '', response)
    return response
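# Illustrative example (made-up string): preprocess_response("ACGT[MASK] 1")
# splits off everything after "[MASK]" (" 1"), then strips the leading
# whitespace/nucleotide run, returning "1".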

def generate(dna_sequence, task_type, sample_num=1):
    """
    Generates a prediction for the DNA sequence on the selected task.

    Args:
        dna_sequence (str): The input DNA sequence.
        task_type (str): The selected task type.
        sample_num (int): Maximum number of new tokens to generate.

    Returns:
        str: Predicted function label.
    """
    if task_type is None:
        task_type = 'H3'
    # The task token and a [MASK] placeholder are appended to form the prompt.
    dna_sequence = dna_sequence + task_type + "[MASK]"

    tokenized_message = tokenizer(
        [dna_sequence], return_tensors='pt', return_token_type_ids=False, add_special_tokens=True
    ).to(device)

    response = model.generate(**tokenized_message, max_new_tokens=sample_num, do_sample=False)
    reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0].replace(" ", "")
    pred = extract_label(reply, task_type)
    # Fall back gracefully if the model emits something other than '0' or '1'.
    return f"{mapping.get(pred, mapping['No valid prediction'])} {task_type}"

def extract_label(message, task_type):
    """Extracts the prediction label (the first digit run after [MASK]) from
    the model's response. task_type is accepted for API symmetry but unused."""
    mask_token = '[MASK]'
    if mask_token not in message:
        return "No valid prediction"
    answer = message.split(mask_token, 1)[1]
    match = re.search(r'\d+', answer)
    return match.group() if match else "No valid prediction"
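# Example: extract_label("ACGTH3[MASK]1", "H3") -> "1"; a reply without
# "[MASK]" or without any digits yields "No valid prediction".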

# Gradio interface
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input DNA Sequence", placeholder="Enter a DNA sequence"),
        gr.Dropdown(choices=tasks, label="Select Task Type"),
    ],
    outputs=gr.Textbox(label="Predicted Type"),
    title="Omni-DNA Multitask Prediction",
    description="Select a DNA-related task and input a sequence to generate function predictions.",
)

if __name__ == "__main__":
    interface.launch()
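    # To expose a temporary public URL (useful when running on a remote
    # machine), Gradio also supports: interface.launch(share=True)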