heron_chat_git / app.py
import logging
import os
import sys
import time
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoProcessor,
    StoppingCriteria,
    TextIteratorStreamer,
)
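
# Clone and install heron on first launch (e.g. in a fresh Space container).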
if not os.path.exists("heron"):
    os.system(
        "git clone https://github.com/turingmotors/heron.git"
        " && export CUDA_HOME=/usr/local/cuda; pip install -e heron"
    )
sys.path.insert(0, "./heron")

from heron.models.git_llm.git_llama import GitLlamaConfig, GitLlamaForCausalLM

logger = logging.getLogger(__name__)


# This class is copied from llava: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
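# It stops generation once any of the given keywords appears in the newly
# generated text (here, the "##" turn marker passed in as EOS_WORDS).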
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so the keyword ids can match mid-sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]
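
    # Two-stage check: first compare the trailing token ids against each
    # keyword's ids, then decode the last few tokens and search for the
    # keyword as a substring.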
    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only batch size 1 is supported"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # Element-wise comparison needs .all(); a bare tensor comparison
            # raises for multi-token keywords.
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
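

# Build the "##human: ...\n##gpt: " prompt format the model expects; note
# that only the human side of each turn is included in the prompt.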
def preprocess(history, image):
    text = ""
    for one_history in history:
        text += f"##human: {one_history[0]}\n##gpt: "

    # do preprocessing
    inputs = processor(
        text,
        image,
        return_tensors="pt",
        truncation=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    return inputs


def add_text(textbox, history):
    # hard text threshold
    if len(textbox) > 512:
        textbox = textbox[:512]

    history = history + [(textbox, None)]
    return "", history


title_markdown = """
# Heron Chat Demo
- Model: [heron-chat-git-Llama-2-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-Llama-2-7b-v0)
- Code: [Heron](https://github.com/turingmotors/heron)
"""
def stream_bot(imagebox, history):
    # do preprocessing
    inputs = preprocess(history, imagebox)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
    )

    inputs.update(
        dict(
            streamer=streamer,
            max_new_tokens=max_length,
            stopping_criteria=[stopping_criteria],
            do_sample=True,
            temperature=0.4,
            no_repeat_ngram_size=2,
        )
    )
    thread = Thread(target=model.generate, kwargs=inputs)
    thread.start()

    history[-1][1] = ""
    for new_text in streamer:
        history[-1][1] += new_text
        # Trim trailing "#" so a partially streamed "##" stop marker never shows.
        while history[-1][1].endswith("#"):
            history[-1][1] = history[-1][1][:-1]
        time.sleep(0.05)
        yield history
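

# Clear the last bot reply; the chained stream_bot call then regenerates it.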
def regenerate(history):
    history[-1] = (history[-1][0], None)
    return history


def clear_history():
    return [], "", None
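

# Lay out the demo: image input and example prompts on the left, the chatbot
# and its controls on the right, then wire up the events.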
def build_demo():
    textbox = gr.Textbox(
        show_label=False, placeholder="Enter text and press ENTER", visible=True, container=False
    )

    with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")

                gr.Examples(
                    examples=[
                        [
                            "./images/bus_kyoto.png",
                            "What should you be careful of when driving on this road?",
                        ],
                        [
                            "./images/bear.png",
                            "What is shown in this image?",
                        ],
                        [
                            "./images/water_bus.png",
                            "What is depicted in the picture?",
                        ],
                        [
                            "./images/extreme_ironing.jpg",
                            "What is the unusual aspect of this image?",
                        ],
                        [
                            "./images/heron.png",
                            "What is intriguing about this picture?",
                        ],
                    ],
                    inputs=[imagebox, textbox],
                )

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Heron Chatbot",
                    visible=True,
                    height=550,
                    avatar_images=("./images/user_icon.png", "./images/heron.png"),
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=60):
                        submit_btn = gr.Button(value="Submit", visible=True)
                with gr.Row():
                    regenerate_btn = gr.Button(value="Regenerate", visible=True)
                    clear_btn = gr.Button(value="Clear history", visible=True)

        # Event wiring: each user action updates the history, then streams a reply.
        regenerate_btn.click(regenerate, chatbot, chatbot).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )
        clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])
        textbox.submit(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )
        submit_btn.click(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )
    return demo


if __name__ == "__main__":
    EOS_WORDS = "##"  # the turn marker also serves as the stop sequence

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = 512

    vision_model_name = "openai/clip-vit-large-patch14-336"
    MODEL_NAME = "turing-motors/heron-chat-git-Llama-2-7b-v0"
    PROCESSOR_PATH = "turing-motors/heron-chat-git-Llama-2-7b-v0"

    # prepare a pretrained model
    git_config = GitLlamaConfig.from_pretrained(MODEL_NAME)
    git_config.set_vision_configs(
        num_image_with_embedding=1, vision_model_name=vision_model_name
    )
    model = GitLlamaForCausalLM.from_pretrained(
        MODEL_NAME, config=git_config, torch_dtype=torch.float16
    )
    model.eval()
    model.to(device)

    # prepare a processor
    processor = AutoProcessor.from_pretrained(PROCESSOR_PATH)

    demo = build_demo()
    # Single worker with a small queue; public share links are disabled.
    demo.queue(concurrency_count=1, max_size=5, api_open=False).launch(share=False)