import gradio as gr
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation_esm import Chat, CONV_VISION
import esm

# ProteinGPT Initialization Function
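# Builds the ProteinGPT model and its vis_processor from the config, then wraps them
# in a Chat helper. Everything is kept on CPU here; moving the model and Chat to a
# GPU device (e.g. 'cuda:0') should also work, assuming the checkpoint fits in memory.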
def initialize_chat(args):
    cfg = Config(args)
    model_config = cfg.model_cfg
    model_config.device_8bit = 0
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cpu')
    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
    chat = Chat(model, vis_processor, device='cpu')
    return chat

# Gradio Reset Function
def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return (None,  # Clear the chatbot history
            gr.update(value=None, interactive=True),  # Clear the structure file input
            gr.update(value=None, interactive=True),  # Clear the sequence file input
            gr.update(placeholder='Please upload your protein structure and sequence first', interactive=False),
            gr.update(value="Upload & Start Chat", interactive=True),
            chat_state, img_list)

# Upload Function
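# Both inputs are file paths to precomputed embedding tensors saved with torch.save
# (hence the .pt check); presumably an ESM-IF1 structure embedding and an ESM-2
# sequence embedding, matching how ProteinGPT encodes proteins.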
def upload_protein(structure, sequence, text_input, chat_state):
    # Check if structure and sequence files are valid
    if structure is None or not structure.endswith(".pt"):
        return (None, None, gr.update(placeholder="Invalid structure file, must be a .pt file.", interactive=False), gr.update(interactive=True), chat_state, None)
    if sequence is None or not sequence.endswith(".pt"):
        return (None, None, gr.update(placeholder="Invalid sequence file, must be a .pt file.", interactive=False), gr.update(interactive=True), chat_state, None)
    
    # Load protein structure and sequence
    pdb_embedding = torch.load(structure, map_location=torch.device('cpu'))
    sample_pdb = pdb_embedding.to('cpu')
    
    seq_embedding = torch.load(sequence, map_location=torch.device('cpu'))
    sample_seq = seq_embedding.to('cpu')

    # Initialize the conversation state
    chat_state = CONV_VISION.copy()
    img_list = []
    
    # Upload protein data
    llm_message = chat.upload_protein(sample_pdb, sample_seq, chat_state, img_list)
    
    # Return the required outputs
    return (gr.update(interactive=False),  # Disable structure file input
            gr.update(interactive=False),  # Disable sequence file input
            gr.update(interactive=True, placeholder='Type and press Enter'),  # Enable the text input box
            gr.update(value="Start Chatting", interactive=False),  # Update upload button state
            chat_state,  # Return the conversation state
            img_list)  # Return the cached protein embeddings (the img_list name is kept for API compatibility)

# Ask Function
def gradio_ask(user_message, chatbot, chat_state):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state

# Answer Function
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
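    # Cast the cached protein embeddings to fp16 before generation; this presumably
    # matches the dtype the model's projection layer expects.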
    img_list = [mat.half() for mat in img_list]
    llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=num_beams, temperature=temperature, max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list

# Command-line Argument Parsing
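# Example invocation (assuming this script is saved as demo.py and the default
# config exists at configs/evaluation.yaml):
#   python demo.py --cfg-path configs/evaluation.yaml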
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", help="path to configuration file.", default='configs/evaluation.yaml')
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config; key-value pairs "
        "in xxx=yyy format will be merged into the config.",
    )
    args = parser.parse_args()
    return args

# Demo Gradio Interface
title = """<h1 align="center">Demo of ProteinGPT</h1>"""
description = """<h3>Upload your protein sequence and structure and start chatting with your protein!</h3>"""
article = """<div style='display:flex; gap: 0.25rem; '><a href='https://huggingface.co/AI-BIO/ProteinGPT-Llama3'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2408.11363'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>"""

args = parse_args()  # Parse arguments to get config and model info
chat = initialize_chat(args)  # Initialize ProteinGPT model

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
            structure = gr.File(type="filepath", label="Upload Protein Structure", show_label=True)
            sequence = gr.File(type="filepath", label="Upload Protein Sequence", show_label=True)
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
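            # Generation controls: num_beams widens beam search (more thorough, slower);
            # temperature scales sampling randomness (lower values are more deterministic).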
            num_beams = gr.Slider(minimum=1, maximum=5, value=1, step=1, interactive=True, label="Number of beams")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature")

        with gr.Column():
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='ProteinGPT')
            text_input = gr.Textbox(label='User', placeholder='Please upload your protein structure and sequence first', interactive=False)
    
    upload_button.click(upload_protein, 
                    [structure, sequence, text_input, chat_state], 
                    [structure, sequence, text_input, upload_button, chat_state, img_list])
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list])
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, structure, sequence, text_input, upload_button, chat_state, img_list], queue=False)

demo.launch(share=True)
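# share=True also requests a temporary public Gradio link in addition to the local
# server; set share=False to keep the demo local only.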