import gradio as gr
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation_esm import Chat, CONV_VISION
import esm

# Placeholder shown in the chat textbox until both protein files are uploaded.
# Shared by the initial UI and gradio_reset so the two never drift apart.
UPLOAD_FIRST_PLACEHOLDER = 'Please upload your protein structure and sequence first'


# ProteinGPT Initialization Function
def initialize_chat(args):
    """Build the ProteinGPT ``Chat`` wrapper on CPU.

    Args:
        args: Namespace produced by ``parse_args``; passed straight to ``Config``.

    Returns:
        A ``Chat`` object wrapping the CPU model and the processor built from
        ``cc_sbu_align.vis_processor.train`` in the config.
    """
    cfg = Config(args)
    model_config = cfg.model_cfg
    model_config.device_8bit = 0
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cpu')

    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
    chat = Chat(model, vis_processor, device='cpu')
    return chat


# Gradio Reset Function
def gradio_reset(chat_state, img_list):
    """Clear the conversation and re-enable the upload controls.

    Args:
        chat_state: current conversation object (or None); its messages are emptied.
        img_list: current list of uploaded embeddings (or None); replaced by [].

    Returns:
        A 7-tuple that lines up one-to-one with the ``clear.click`` outputs:
        ``[chatbot, structure, sequence, text_input, upload_button, chat_state, img_list]``.
    """
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return (
        None,                                      # chatbot: drop the transcript
        gr.update(value=None, interactive=True),   # structure file input
        gr.update(value=None, interactive=True),   # sequence file input
        gr.update(placeholder=UPLOAD_FIRST_PLACEHOLDER, interactive=False),  # text_input
        gr.update(value="Upload & Start Chat", interactive=True),            # upload_button
        chat_state,
        img_list,
    )


# Upload Function
def upload_protein(structure, sequence, text_input, chat_state):
    """Load the uploaded embeddings and start a fresh conversation.

    Args:
        structure: file path of the uploaded structure embedding (``.pt``), or None.
        sequence: file path of the uploaded sequence embedding (``.pt``), or None.
        text_input: current textbox value; unused, kept for the event signature.
        chat_state: current conversation state; returned unchanged on invalid input.

    Returns:
        A 6-tuple matching the ``upload_button.click`` outputs:
        ``[structure, sequence, text_input, upload_button, chat_state, img_list]``.
    """
    def _invalid(message):
        # Clear both file inputs, show the error in the (still disabled) textbox,
        # and keep the upload button clickable so the user can retry.
        return (None,
                None,
                gr.update(placeholder=message, interactive=False),
                gr.update(interactive=True),
                chat_state,
                None)

    # Check if structure and sequence files are valid
    if structure is None or not structure.lower().endswith(".pt"):
        return _invalid("Invalid structure file, must be a .pt file.")
    if sequence is None or not sequence.lower().endswith(".pt"):
        return _invalid("Invalid sequence file, must be a .pt file.")

    # SECURITY(review): torch.load unpickles arbitrary Python objects, and these
    # files come from users of a public share link (demo.launch(share=True)).
    # Consider torch.load(..., weights_only=True) if the installed torch supports it.
    cpu = torch.device('cpu')
    sample_pdb = torch.load(structure, map_location=cpu).to('cpu')
    sample_seq = torch.load(sequence, map_location=cpu).to('cpu')

    # Every successful upload starts a brand-new conversation.
    chat_state = CONV_VISION.copy()
    img_list = []
    # NOTE(review): the return value was never used; presumably chat.upload_protein
    # fills img_list in place (img_list is later passed to chat.answer) -- confirm.
    chat.upload_protein(sample_pdb, sample_seq, chat_state, img_list)

    return (gr.update(interactive=False),                                    # lock structure input
            gr.update(interactive=False),                                    # lock sequence input
            gr.update(interactive=True, placeholder='Type and press Enter'),  # enable the text input box
            gr.update(value="Start Chatting", interactive=False),            # update upload button state
            chat_state,                                                      # conversation state
            img_list)                                                        # uploaded embeddings


# Ask Function
def gradio_ask(user_message, chatbot, chat_state):
    """Send the user's message to the conversation and add a pending chat turn.

    Blank (empty or whitespace-only) messages are rejected. In that case the
    chatbot and conversation are returned unchanged and the textbox shows a hint.

    Returns:
        ``(text_input_update, chatbot, chat_state)``; on success the textbox is
        cleared and ``chatbot`` gains a ``[user_message, None]`` turn.
    """
    if not user_message or not user_message.strip():
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state)
    # A None answer marks the turn as pending for gradio_answer. The chatbot may
    # be None right after gradio_reset.
    chatbot = (chatbot or []) + [[user_message, None]]
    return '', chatbot, chat_state


# Answer Function
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """Generate the model's reply for the pending turn created by ``gradio_ask``.

    This is chained after ``gradio_ask`` with ``.then`` and therefore runs even
    when ``gradio_ask`` rejected the input. Without a pending ``[msg, None]``
    turn it returns its inputs unchanged. Otherwise it would overwrite the
    previous answer or index an empty transcript.

    Returns:
        ``(chatbot, chat_state, img_list)`` with the last turn's answer filled in.
    """
    if not chatbot or chatbot[-1][1] is not None:
        return chatbot, chat_state, img_list
    # NOTE(review): the model is moved to CPU in initialize_chat; confirm that
    # half-precision embeddings are what chat.answer expects in that setting.
    img_list = [mat.half() for mat in img_list]
    llm_message = chat.answer(conv=chat_state,
                              img_list=img_list,
                              max_new_tokens=300,
                              num_beams=num_beams,
                              temperature=temperature,
                              max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list


# Command-line Argument Parsing
def parse_args():
    """Parse the demo's command-line options (config path and config overrides)."""
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", help="path to configuration file.", default='configs/evaluation.yaml')
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


# Demo Gradio Interface
title = """

Demo of ProteinGPT

"""
description = """

Upload your protein sequence and structure and start chatting with your protein!

"""
article = """
"""

args = parse_args()  # Parse arguments to get config and model info
chat = initialize_chat(args)  # Initialize ProteinGPT model; used as a global by the handlers above

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
            structure = gr.File(type="filepath", label="Upload Protein Structure", show_label=True)
            sequence = gr.File(type="filepath", label="Upload Protein Sequence", show_label=True)
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
            num_beams = gr.Slider(minimum=1, maximum=5, value=1, step=1,
                                  interactive=True, label="Beam search numbers")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1,
                                    interactive=True, label="Temperature")

        with gr.Column():
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='ProteinGPT')
            text_input = gr.Textbox(label='User', placeholder=UPLOAD_FIRST_PLACEHOLDER, interactive=False)

    # Each handler's returned tuple must match its outputs list element-for-element.
    upload_button.click(upload_protein,
                        [structure, sequence, text_input, chat_state],
                        [structure, sequence, text_input, upload_button, chat_state, img_list])
    text_input.submit(gradio_ask,
                      [text_input, chatbot, chat_state],
                      [text_input, chatbot, chat_state]).then(
        gradio_answer,
        [chatbot, chat_state, img_list, num_beams, temperature],
        [chatbot, chat_state, img_list])
    clear.click(gradio_reset,
                [chat_state, img_list],
                [chatbot, structure, sequence, text_input, upload_button, chat_state, img_list],
                queue=False)

demo.launch(share=True)