Monkey / app.py
echo840's picture
Add application file
d28c270
raw
history blame contribute delete
No virus
15.1 kB
from argparse import ArgumentParser
from pathlib import Path
import copy
import gradio as gr
import os
import re
import secrets
import tempfile
from PIL import Image
from monkey_model.modeling_monkey import MonkeyLMHeadModel
from monkey_model.tokenization_qwen import QWenTokenizer
from monkey_model.configuration_monkey import MonkeyConfig
import shutil
from pathlib import Path
import json
# Hugging Face Hub id of the Monkey checkpoint (a local directory such as
# '/home/zhangli/demo/' can be passed instead via --checkpoint-path).
DEFAULT_CKPT_PATH = 'echo840/Monkey'
# Captures the contents of <box>...</box> spans (non-greedy, may span lines).
# NOTE(review): not referenced anywhere else in this file.
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
# Trailing punctuation marks that add_text strips from a question before it is
# sent to the model.  These are full-width (CJK) forms plus a final ASCII '.';
# the previous line had been folded to ASCII, so its '"' terminated the string
# after three characters and the rest of the set became a comment.
PUNCTUATION = "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
# Header rendered at the top of the demo page.
title_markdown = ("""
# Welcome to Monkey
Hello! I'm Monkey, a Large Language and Vision Assistant. Before talking to me, please read the **Operation Guide** and **Terms of Use**.
> Note: This demo represents a more advanced iteration of the chat system, building upon the previous version to deliver an enhanced interactive experience. As a result, we cannot guarantee that the question-answering scenarios presented in the paper can be replicated accurately using this updated version.
## Operation Guide
Click the **Upload** button to upload an image. Then, you can get Monkey's answer in two ways:
- Click the **Generate** and Monkey will generate a description of the image.
- Enter the question in the dialog box, click the **Submit**, and Monkey will answer the question based on the image.
- Click **Clear History** to clear the current image and Q&A content.
""")
# Terms of use rendered at the bottom of the demo page.
policy_markdown = ("""
## Terms of Use
By using this service, users are required to agree to the following terms:
- Monkey is for research use only and unauthorized commercial use is prohibited. For any query, please contact the author.
- Monkey's generation capabilities are limited, so we recommend that users do not rely entirely on its answers.
- Monkey's security measures are limited, so we cannot guarantee that the output is completely appropriate. We strongly recommend that users do not intentionally guide Monkey to generate harmful content, including hate speech, discrimination, violence, pornography, deception, etc.
""")
# Unused markdown section with example captioning prompts, kept for reference:
# ## Some Prompt Examples
# In order to generate more detailed captions, we provide some input examples so that you can conduct more interesting explorations.
# - Generate the detailed caption in English.
# - Explain the visual content of the image in great detail.
# - Analyze the image in a comprehensive and detailed manner.
# - Describe the image in as much detail as possible in English without duplicating it.
# - Describe the image in as much detail as possible in English, including as many elements from the image as possible, but without repetition.
def _get_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``checkpoint_path``, ``cpu_only``, ``share``,
        ``inbrowser``, ``server_port`` and ``server_name``.
    """
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    # `--share` used to be store_true with default=True, so it could never be
    # turned off.  Sharing stays on by default and `--share` is still accepted;
    # `--no-share` is the new way to disable it.
    share_group = parser.add_mutually_exclusive_group()
    share_group.add_argument("--share", dest="share", action="store_true",
                             help="Create a publicly shareable link for the interface.")
    share_group.add_argument("--no-share", dest="share", action="store_false",
                             help="Do not create a publicly shareable link.")
    parser.set_defaults(share=True)
    parser.add_argument("--inbrowser", action="store_true", default=False,
                        help="Automatically launch the interface in a new tab on the default browser.")
    parser.add_argument("--server-port", type=int, default=8000,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Demo server name.")
    return parser.parse_args(argv)
def _load_model_tokenizer(args):
    """Load the Monkey tokenizer and model named by ``args.checkpoint_path``.

    The model is placed on "cpu" when ``args.cpu_only`` is set and on "cuda"
    otherwise, and is switched to eval mode.  The tokenizer is configured to
    pad on the left, using its ``eod_id`` as the pad token.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    ckpt = args.checkpoint_path
    tokenizer = QWenTokenizer.from_pretrained(ckpt, trust_remote_code=True)

    target_device = "cpu" if args.cpu_only else "cuda"
    model = MonkeyLMHeadModel.from_pretrained(
        ckpt,
        device_map=target_device,
        trust_remote_code=True,
    ).eval()

    tokenizer.padding_side = 'left'
    tokenizer.pad_token_id = tokenizer.eod_id
    return model, tokenizer
def _parse_text(text):
    """Convert raw text into the HTML-flavoured markup shown in the chatbot.

    Empty lines are dropped and the remaining lines are joined with ``<br>``
    (no separator before the first line).  A line containing a triple-backtick
    fence toggles a code block.  An opening fence becomes
    ``<pre><code class="language-LANG">``, where LANG is the text after the last
    backtick on that line.  A closing fence becomes ``<br></code></pre>``.
    Inside a code block, characters that HTML/Markdown would interpret are
    replaced by entities so the code is displayed verbatim.

    Args:
        text: Raw, possibly multi-line text.

    Returns:
        A single string without newline characters.
    """
    # Single-pass escape table for code-block lines.  '&' is escaped too, so
    # that literal entities in code (e.g. "&lt;") are shown as written.
    code_escapes = str.maketrans({
        "&": "&amp;",
        "`": r"\`",
        "<": "&lt;",
        ">": "&gt;",
        " ": "&nbsp;",
        "*": "&ast;",
        "_": "&lowbar;",
        "-": "&#45;",
        ".": "&#46;",
        "!": "&#33;",
        "(": "&#40;",
        ")": "&#41;",
        "$": "&#36;",
    })
    lines = [line for line in text.split("\n") if line != ""]
    in_code = False
    for i, line in enumerate(lines):
        if "```" in line:
            in_code = not in_code
            if in_code:
                lang = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{lang}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if in_code:
                line = line.translate(code_escapes)
            lines[i] = "<br>" + line
    return "".join(lines)
def _launch_demo(args, model, tokenizer):
    """Build the Monkey chat UI around ``model``/``tokenizer`` and launch Gradio.

    Args:
        args: Namespace from ``_get_args``.  NOTE(review): the ``launch`` call at
            the bottom still hard-codes host, port and share, so those CLI
            options are currently not read here.
        model: Loaded Monkey model.  It is used through ``model.generate`` and
            ``model.device``.
        tokenizer: Matching tokenizer.  It encodes prompts and decodes
            generations, and its ``eod_id`` serves as both the pad and EOS token.
    """
    # Fixed instruction used by the "Generate" (caption) button.
    caption_prompt = (
        "Describe the image in as much detail as possible in English, "
        "including as many elements from the image as possible, "
        "but without repetition. Answer: "
    )
    no_image_msg = "Please upload a picture."
    # Maximum number of whitespace-separated words of prompt text.  The value is
    # kept from the previous implementation, which budgeted ~2048 tokens with the
    # image tag assumed to expand to ~1288 tokens.  TODO confirm.
    max_prompt_words = 760

    def _image_path(chatbot):
        """Return the uploaded image path stored in chat turn 0, or None.

        add_file stores an upload as ``((path,), None)``, so the path is
        ``chatbot[0][0][0]``.  A text turn, ``None``, a missing entry, or a file
        PIL cannot open all mean that no usable image has been uploaded.
        """
        try:
            first = chatbot[0][0]
            if isinstance(first, str):  # a typed question, not an upload
                return None
            img_path = first[0]
            # Only check that PIL can open the file, and close it right away.
            with Image.open(img_path):
                pass
        except Exception:  # any lookup/open failure == "no usable image"
            return None
        return img_path

    def _build_prompt(img_path, question, history_turns):
        """Assemble the prompt for ``question`` about the image at ``img_path``.

        Earlier ``(question, answer)`` turns are prepended as
        ``"<q> Answer: <a> "`` text, and the image tag plus the current question
        always come last.  Whitespace is normalised to single spaces.  When the
        prompt exceeds ``max_prompt_words``, the oldest history words are
        dropped first, so the image tag and the question are never cut.
        """
        tail_words = f'<img>{img_path}</img> {question} Answer: '.split()
        past = []
        for q, a in history_turns:
            # Caption turns are stored as (None, answer) and failed turns may
            # keep a None answer; never put the literal "None" in the prompt.
            prefix = f'{q} ' if q is not None else ''
            past.append(f'{prefix}Answer: {a if a is not None else ""} ')
        past_words = ''.join(past).split()
        budget = max(0, max_prompt_words - len(tail_words))
        kept = past_words[-budget:] if budget else []
        return " ".join(kept + tail_words)

    def _generate(prompt, **sampling):
        """Run ``model.generate`` on ``prompt`` and return the decoded new text."""
        encoded = tokenizer(prompt, return_tensors='pt', padding='longest')
        # Follow the model's device instead of calling .cuda(), so that
        # --cpu-only (device_map="cpu") works as well.
        input_ids = encoded.input_ids.to(model.device)
        attention_mask = encoded.attention_mask.to(model.device)
        pred = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            min_new_tokens=1,
            length_penalty=3,
            num_return_sequences=1,
            output_hidden_states=True,
            use_cache=True,
            pad_token_id=tokenizer.eod_id,
            eos_token_id=tokenizer.eod_id,
            **sampling,
        )
        # Decode only the tokens generated after the prompt.
        new_tokens = pred[0][input_ids.size(1):].cpu()
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def predict(_chatbot, task_history):
        """Answer the latest question about the uploaded image, using chat history.

        add_text has already appended the question to both histories as
        ``(text, None)``.  This fills in the answer of that last turn in place
        and returns the updated chat display.
        """
        chat_query = _chatbot[-1][0]
        raw_query = task_history[-1][0]
        question = _parse_text(raw_query)
        img_path = _image_path(_chatbot)
        if img_path is None:
            _chatbot[-1] = (_parse_text(chat_query), no_image_msg)
            task_history[-1] = (raw_query, _parse_text(no_image_msg))
            return _chatbot

        # Turn 0 is the image and the last turn is the current question.
        prompt = _build_prompt(img_path, question, _chatbot[1:-1])
        print(prompt)
        response = _generate(prompt, do_sample=False, num_beams=1, max_new_tokens=512)

        _chatbot[-1] = (_parse_text(chat_query), response)
        full_response = _parse_text(response)
        # The task history records the single-turn prompt form of the query.
        task_history[-1] = (f'<img>{img_path}</img> {question} Answer: ', full_response)
        print("Monkey: " + full_response)
        return _chatbot

    def caption(_chatbot, task_history):
        """Append a sampled, detailed description of the uploaded image as a new turn."""
        print("User: " + _parse_text(caption_prompt))
        img_path = _image_path(_chatbot)
        if img_path is None:
            response = no_image_msg
        else:
            # The trailing space after the instruction is kept from the original prompt.
            prompt = f'<img>{img_path}</img> {caption_prompt} '
            print(prompt)
            response = _generate(prompt, do_sample=True, temperature=0.7, max_new_tokens=250)
        _chatbot.append((None, response))
        full_response = _parse_text(response)
        task_history.append((None, full_response))
        print("Monkey: " + full_response)
        return _chatbot

    def add_text(history, task_history, text):
        """Append ``text`` as a new, unanswered turn and clear the input box.

        The display history gets the ``_parse_text`` markup of ``text``.  The
        task history gets the raw text, with one trailing PUNCTUATION mark
        removed unless the preceding character is also in PUNCTUATION.

        Returns:
            ``(history, task_history, "")``.  The empty string clears the textbox.
        """
        task_text = text
        if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
            task_text = text[:-1]
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ""

    def add_file(history, task_history, file):
        """Start a new conversation whose first turn is the uploaded image.

        A copy of the upload is saved under
        ``./history/test_image/<name of the upload's parent directory>/``.
        Both histories are reset to a single image turn.
        """
        # Depending on the Gradio version / UploadButton ``type``, ``file`` is a
        # tempfile-like object with ``.name`` or a plain path string.
        file_path = getattr(file, "name", file)
        save_dir = Path("./history/test_image") / Path(file_path).parent.name
        save_dir.mkdir(exist_ok=True, parents=True)
        shutil.copy(file_path, save_dir)
        history = [((file_path,), None)]
        task_history = [((file_path,), None)]
        return history, task_history

    def reset_state(task_history):
        """Clear the task history in place and return an empty chat display."""
        task_history.clear()
        return []

    def _logo(path, **extra):
        """Create a 100x100 logo image component, without label or download button."""
        return gr.Image(Image.open(path), height=100, width=100, show_download_button=False,
                        label="Generated images", show_label=False, **extra)

    with gr.Blocks() as demo:
        gr.Markdown(title_markdown)
        chatbot = gr.Chatbot(label='Monkey', elem_classes="control-height", height=600,
                             avatar_images=("./images/logo_user.png", "./images/logo_monkey.png"),
                             layout="bubble", bubble_full_width=False, show_copy_button=True)
        query = gr.Textbox(lines=1, label='Input')
        task_history = gr.State([])

        with gr.Row():
            empty_bin = gr.Button("Clear History")
            submit_btn = gr.Button("Submit")
            generate_btn_en = gr.Button("Generate")
            addfile_btn = gr.UploadButton("Upload", file_types=["image"])

        # add_text returns (history, task_history, ""), so the textbox is its
        # third output.  That clears the input, which previously needed a
        # second click handler, and matches the number of returned values.
        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history, query]).then(
            predict, [chatbot, task_history], [chatbot], show_progress=True
        )
        generate_btn_en.click(caption, [chatbot, task_history], [chatbot], show_progress=True)
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history],
                           show_progress=True, scroll_to_output=True)

        with gr.Row(variant="compact"):
            # The side columns only hold images created with render=False,
            # which nothing renders.  Presumably they serve as spacers around
            # the centred logo row.
            with gr.Column(scale=2):
                with gr.Row():
                    _logo("./images/logo_monkey.png", render=False)
                    _logo("./images/logo_hust.png", render=False)
            with gr.Column(scale=4):
                with gr.Row():
                    for logo_name in ("logo_monkey", "logo_vlr", "logo_hust", "logo_king"):
                        _logo(f"./images/{logo_name}.png")
            with gr.Column(scale=2):
                with gr.Row():
                    _logo("./images/logo_monkey.png", render=False)
                    _logo("./images/logo_hust.png", render=False)
        gr.Markdown(policy_markdown)

    # NOTE(review): host, port and share are hard-coded, so --server-name,
    # --server-port, --share/--no-share and --inbrowser are ignored.  Switching
    # to args.* would change the deployed defaults (127.0.0.1:8000); confirm first.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7682,
        share=True,
    )
def main():
    """Script entry point: parse CLI options, load Monkey, then serve the demo."""
    cli_args = _get_args()
    loaded_model, loaded_tokenizer = _load_model_tokenizer(cli_args)
    _launch_demo(cli_args, loaded_model, loaded_tokenizer)


if __name__ == '__main__':
    main()