# Scraped from a Hugging Face file page: author EdwardoSunny, "anon", commit 236b7d5, 6.24 kB
import gradio as gr
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation_esm import Chat, CONV_VISION
import esm
# ProteinGPT Initialization Function
def initialize_chat(args):
    """Build the ProteinGPT model and wrap it in a CPU-placed ``Chat`` session.

    Args:
        args: Parsed command-line namespace, handed unchanged to ``Config``.

    Returns:
        A ``Chat`` object wrapping the model and the configured processor,
        created with ``device='cpu'``.
    """
    device = 'cpu'
    cfg = Config(args)

    # Model: resolve the registered class named by the config's ``arch`` field.
    model_config = cfg.model_cfg
    # NOTE(review): device_8bit presumably only matters for 8-bit loading;
    # it is forced to 0 here before the model is built.
    model_config.device_8bit = 0
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to(device)

    # Processor: taken from the cc_sbu_align dataset's training config.
    processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    processor_cls = registry.get_processor_class(processor_cfg.name)
    vis_processor = processor_cls.from_config(processor_cfg)

    return Chat(model, vis_processor, device=device)
# Gradio Reset Function
def gradio_reset(chat_state, img_list):
    """Reset the UI and conversation state when "Restart" is clicked.

    The returned tuple is positional and must line up one-to-one with the
    outputs wired to ``clear.click``:
    ``[chatbot, structure, sequence, text_input, upload_button, chat_state, img_list]``.

    Args:
        chat_state: Current conversation object, or ``None`` before any upload.
            Its ``messages`` list is emptied in place.
        img_list: Current list of uploaded tensors, or ``None``.

    Returns:
        A 7-tuple with one value/update per output component listed above.
    """
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return (
        None,  # chatbot: clear the visible history
        gr.update(value=None, interactive=True),  # structure: clear and re-enable
        gr.update(value=None, interactive=True),  # sequence: clear and re-enable
        gr.update(placeholder='Please upload your protein structure and sequence first', interactive=False),  # text_input
        gr.update(value="Upload & Start Chat", interactive=True),  # upload_button
        chat_state,
        img_list,
    )
# Upload Function
def upload_protein(structure, sequence, text_input, chat_state):
    """Load the uploaded structure/sequence embeddings and start a conversation.

    The returned tuple is positional and must line up one-to-one with the
    outputs wired to ``upload_button.click``:
    ``[structure, sequence, text_input, upload_button, chat_state, img_list]``.

    Args:
        structure: Filepath of the uploaded structure file (must end in ``.pt``), or ``None``.
        sequence: Filepath of the uploaded sequence file (must end in ``.pt``), or ``None``.
        text_input: Current text-box value. Unused; accepted because it is wired as an input.
        chat_state: Current conversation object. Returned unchanged when validation fails.

    Returns:
        A 6-tuple with one value/update per output component listed above.
    """

    def _rejected(reason):
        # Clear the file inputs and show *reason* in the text box. The box stays
        # disabled because there is no conversation yet: chatting would call
        # chat.ask with no state. The upload button stays enabled for a retry.
        return (None,
                None,
                gr.update(placeholder=reason, interactive=False),
                gr.update(interactive=True),
                chat_state,
                None)

    # Validate both uploads before loading anything.
    if structure is None or not structure.endswith(".pt"):
        return _rejected("Invalid structure file, must be a .pt file.")
    if sequence is None or not sequence.endswith(".pt"):
        return _rejected("Invalid sequence file, must be a .pt file.")

    # SECURITY(review): torch.load unpickles its input, so a crafted upload can
    # execute arbitrary code on the server. If these files only ever hold
    # tensors, consider torch.load(..., weights_only=True) (torch >= 1.13).
    cpu = torch.device('cpu')
    sample_pdb = torch.load(structure, map_location=cpu).to('cpu')
    sample_seq = torch.load(sequence, map_location=cpu).to('cpu')

    # Start a fresh conversation. img_list is passed to chat.upload_protein,
    # which presumably fills it in place (gradio_answer later reads tensors
    # from it) — its return value is not needed here.
    chat_state = CONV_VISION.copy()
    img_list = []
    chat.upload_protein(sample_pdb, sample_seq, chat_state, img_list)

    return (gr.update(interactive=False),  # lock the structure file input
            gr.update(interactive=False),  # lock the sequence file input
            gr.update(interactive=True, placeholder='Type and press Enter'),  # enable chatting
            gr.update(value="Start Chatting", interactive=False),  # lock the upload button
            chat_state,
            img_list)
# Ask Function
def gradio_ask(user_message, chatbot, chat_state):
    """Record the user's message in the conversation and the visible history.

    Args:
        user_message: Text submitted from the input box.
        chatbot: Visible history as a list of ``[user, bot]`` pairs. ``None``
            is treated as an empty history.
        chat_state: Conversation object that ``chat.ask`` records the message into.

    Returns:
        ``(text_input_update, chatbot, chat_state)``.
        On blank input, the text box shows a warning placeholder and the
        history is returned unchanged.
        Otherwise, the box is cleared and a ``[user_message, None]`` row is
        appended, waiting for the answer.
    """
    # Reject empty and whitespace-only messages.
    if not (user_message and user_message.strip()):
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = (chatbot or []) + [[user_message, None]]
    return '', chatbot, chat_state
# Answer Function
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """Generate the model's reply to the latest user turn and show it.

    Args:
        chatbot: Visible history as a list of ``[user, bot]`` pairs. The last
            row's bot slot receives the reply.
        chat_state: Conversation object passed to ``chat.answer`` as ``conv``.
        img_list: Tensors collected at upload time. They are cast to half
            precision before generation.
        num_beams: Beam count from the UI slider.
        temperature: Sampling temperature from the UI slider.

    Returns:
        ``(chatbot, chat_state, img_list)``, where ``img_list`` holds the
        half-precision tensors. When there is nothing to answer, all three
        inputs are returned untouched.
    """
    # This runs after gradio_ask via ``.then`` even when gradio_ask rejected
    # an empty message. Only generate when the latest row still awaits a
    # reply; otherwise the previous answer would be overwritten, or an empty
    # history would raise IndexError.
    if not chatbot or chatbot[-1][1] is not None:
        return chatbot, chat_state, img_list
    # NOTE(review): initialize_chat places the model on CPU. Some CPU ops do
    # not support float16 — confirm this cast is intended.
    img_list = [mat.half() for mat in img_list]
    llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300,
                              num_beams=num_beams, temperature=temperature, max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list
# Command-line Argument Parsing
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` with two attributes:

        - ``cfg_path``: str, defaults to ``'configs/evaluation.yaml'``.
        - ``options``: list of ``xxx=yyy`` override strings, or ``None`` when
          the flag is omitted.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", help="path to configuration file.", default='configs/evaluation.yaml')
    # The old help text pointed users at a --cfg-options flag that this parser
    # never defined; --options is the only override flag.
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file.",
    )
    return parser.parse_args(argv)
# Demo Gradio Interface
title = """<h1 align="center">Demo of ProteinGPT</h1>"""
description = """<h3>Upload your protein sequence and structure and start chatting with your protein!</h3>"""
article = """<div style='display:flex; gap: 0.25rem; '><a href='https://huggingface.co/AI-BIO/ProteinGPT-Llama3'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a></div>"""
args = parse_args()  # Parse arguments to get config and model info
chat = initialize_chat(args)  # Initialize ProteinGPT model; the callbacks above read this global
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)
    with gr.Row():
        # Left column: uploads and generation controls.
        with gr.Column(scale=0.5):
            structure = gr.File(type="filepath", label="Upload Protein Structure", show_label=True)
            sequence = gr.File(type="filepath", label="Upload Protein Sequence", show_label=True)
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
            num_beams = gr.Slider(minimum=1, maximum=5, value=1, step=1, interactive=True, label="Beam search numbers")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature")
        # Right column: per-session state and the chat itself.
        with gr.Column():
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='ProteinGPT')
            # Disabled until a successful upload enables it.
            text_input = gr.Textbox(label='User', placeholder='Please upload your protein structure and sequence first', interactive=False)
    # Each outputs list must match its callback's return tuple position by position.
    upload_button.click(upload_protein,
                        [structure, sequence, text_input, chat_state],
                        [structure, sequence, text_input, upload_button, chat_state, img_list])
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list])
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, structure, sequence, text_input, upload_button, chat_state, img_list], queue=False)
demo.launch(share=True)