heron_chat_blip / app.py
tanahhh's picture
add heron and iron
ccd8557
raw history blame
No virus
7.39 kB
import logging
import os
import sys
import time
from threading import Thread
import gradio as gr
import torch
from transformers import AutoProcessor, StoppingCriteria, TextIteratorStreamer, LlamaTokenizer
# Fetch the heron repository (it provides the VideoBlip model/processor classes
# imported below) and install it in editable mode. The exit status is ignored:
# if ./heron already exists (e.g. on a restart) `git clone` fails and, because
# the commands are chained with `&&`, the `pip install` step is skipped too.
os.system(
    "git clone https://github.com/turingmotors/heron && cd heron && pip install -e ."
)
# Make the cloned package importable from this already-running interpreter
# (presumably because an editable install performed in a subprocess is not
# picked up by the current process's sys.path).
sys.path.insert(0, "./heron")
from heron.models.video_blip import VideoBlipForConditionalGeneration, VideoBlipProcessor

# Module-level logger.
logger = logging.getLogger(__name__)
# This class is adapted from LLaVA: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as one of ``keywords`` has been produced.

    Two checks run each time the criterion is called:

    1. token level - the trailing token ids of the sequence equal the token
       ids of a keyword;
    2. text level - a keyword occurs in the decoded text of the last (at
       most 3) newly generated tokens.

    Only batch size 1 is supported.

    Args:
        keywords: list of stop strings.
        tokenizer: tokenizer used to encode the keywords and decode output.
        input_ids: prompt token ids; ``input_ids.shape[1]`` marks where the
            generated continuation starts.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a BOS token prepended by the tokenizer so the ids can be
            # matched against the tail of the generated sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        num_generated = output_ids.shape[1] - self.start_len
        if num_generated <= 0:
            # Nothing generated yet. Without this guard the offset would be 0,
            # and ``[:, -0:]`` below selects the whole prompt, which may itself
            # contain a keyword (e.g. the "##" turn markers).
            return False
        offset = min(num_generated, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # torch.equal compares shapes and all values and returns a plain
            # bool. A bare ``if tensor == tensor`` raises RuntimeError for
            # multi-token keywords (ambiguous truth value) and fails on a shape
            # mismatch when the sequence is shorter than the keyword.
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)
def preprocess(history, image):
    """Turn the chat history and the current image into model-ready tensors.

    The prompt uses "##human: <user text>" / "##gpt: <reply>" turns, each
    followed by a newline. Turns that already have a bot reply include that
    reply, so the model sees the whole conversation. The last turn (whose
    reply is about to be generated, reply ``None``) ends with an open
    "##gpt: " for the model to continue.

    Args:
        history: sequence of ``(user_text, bot_text_or_None)`` pairs as kept
            by the Gradio chatbot.
        image: image forwarded unchanged to ``processor``.

    Returns:
        dict of tensors moved to ``device``; ``pixel_values`` is cast to
        float16.
    """
    parts = []
    for turn in history:
        user_text, bot_text = turn[0], turn[1]
        parts.append(f"##human: {user_text}\n##gpt: ")
        if bot_text:
            # Keep earlier answers in the prompt; without them the model would
            # see "##gpt: ##human: ..." for every previous turn.
            parts.append(f"{bot_text}\n")
    text = "".join(parts)

    inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
        truncation=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # The model is run in half precision, so the image tensor must match.
    inputs["pixel_values"] = inputs["pixel_values"].to(device, torch.float16)
    return inputs
def add_text(textbox, history):
    """Record the user's message as a new, not-yet-answered turn.

    Args:
        textbox: the message typed by the user.
        history: current chat history (list of turns).

    Returns:
        ``(cleared_input, updated_history)``: an empty string that resets the
        textbox, and a new list with ``(textbox, None)`` appended. The input
        list itself is left untouched.
    """
    pending_turn = (textbox, None)
    updated_history = history + [pending_turn]
    cleared_input = ""
    return cleared_input, updated_history
def stream_bot(imagebox, history):
    """Generate the reply to the last turn of ``history``, streaming it into the chat.

    ``model.generate`` runs in a background thread and pushes decoded text into
    a ``TextIteratorStreamer``. This generator appends each chunk to the last
    turn's reply (with ``EOS_WORDS`` removed) and yields the history so Gradio
    re-renders the chatbot incrementally.

    Args:
        imagebox: image from the UI, forwarded to ``preprocess``.
        history: chat history; its last entry ``[user_text, reply]`` is the
            turn being answered.

    Yields:
        ``history`` with the last entry's reply filled in so far.
    """
    if not history:
        # Nothing to answer (e.g. "Regenerate" on an empty chat);
        # history[-1] below would raise IndexError.
        yield history
        return

    inputs = preprocess(history, imagebox)
    # Only decoding options belong on the streamer: extra keyword arguments
    # given to it are forwarded to tokenizer.decode and never reach generate().
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
    )
    inputs.update(
        streamer=streamer,
        max_new_tokens=max_length,
        stopping_criteria=[stopping_criteria],
        # Greedy decoding. ``temperature`` is deliberately not set: it has no
        # effect when sampling is disabled.
        do_sample=False,
        no_repeat_ngram_size=2,
        # NOTE(review): the pad id is used as the end-of-sequence id;
        # presumably the model ends replies with pad - confirm.
        eos_token_id=[processor.tokenizer.pad_token_id],
    )
    thread = Thread(target=model.generate, kwargs=inputs)
    thread.start()

    # Rebuild the entry as a mutable [user_text, reply] list; add_text and
    # regenerate create tuple entries, which cannot be assigned into.
    history[-1] = [history[-1][0], ""]
    for new_text in streamer:
        history[-1][1] = (history[-1][1] + new_text).replace(EOS_WORDS, "")
        time.sleep(0.05)
        yield history
def regenerate(history):
    """Clear the bot reply of the last turn so it can be generated again.

    Args:
        history: chat history as ``[user_text, bot_text]`` entries; modified
            in place.

    Returns:
        The same list, with its last entry replaced by ``(user_text, None)``.
        An empty history (e.g. Regenerate clicked before any message) is
        returned unchanged instead of raising IndexError.
    """
    if history:
        history[-1] = (history[-1][0], None)
    return history
def clear_history():
    """Reset the UI state.

    Returns:
        ``(chat_history, textbox_value, image)`` = a fresh empty list, an
        empty string and ``None``.
    """
    empty_chat, empty_text, no_image = [], "", None
    return empty_chat, empty_text, no_image
def build_demo():
    """Build (but do not launch) the Gradio Blocks UI for the Heron chatbot.

    Left column: image input plus clickable example image/question pairs.
    Right column: the chatbot, a message textbox with a Submit button, and
    Regenerate / Clear history buttons.

    Events:
        * Enter in the textbox, or Submit: ``add_text`` (not queued), then
          the reply is streamed by ``stream_bot``.
        * Regenerate: ``regenerate``, then ``stream_bot``.
        * Clear history: ``clear_history`` resets chatbot, textbox and image.

    Returns:
        The ``gr.Blocks`` instance.
    """
    # Created before the Blocks context because gr.Examples (left column)
    # refers to it; it is placed in the layout later via .render().
    textbox = gr.Textbox(
        show_label=False,
        placeholder="Enter text and press ENTER",
        visible=True,
        container=False,
    )
    example_pairs = [
        ("./images/bus_kyoto.png", "この道路を運転する時には何に気をつけるべきですか?"),
        ("./images/bear.png", "この画像には何が写っていますか?"),
        ("./images/water_bus.png", "画像には何が写っていますか?"),
        ("./images/extreme_ironing.jpg", "この画像の面白い点は何ですか?"),
        ("./images/heron.png", "この画像はどういう点が面白いですか?"),
    ]

    with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")
                gr.Examples(
                    examples=[list(pair) for pair in example_pairs],
                    inputs=[imagebox, textbox],
                )
            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Heron Chatbot",
                    visible=True,
                    height=550,
                    avatar_images=("./images/user_icon.png", "./images/heron.png"),
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=60):
                        submit_btn = gr.Button(value="Submit", visible=True)
                with gr.Row():
                    regenerate_btn = gr.Button(value="Regenerate", visible=True)
                    clear_btn = gr.Button(value="Clear history", visible=True)

        def then_stream_reply(event):
            # Follow a history-updating event with the streamed bot reply.
            return event.then(stream_bot, [imagebox, chatbot], [chatbot])

        then_stream_reply(regenerate_btn.click(regenerate, chatbot, chatbot))
        clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])
        for trigger in (textbox.submit, submit_btn.click):
            then_stream_reply(
                trigger(add_text, [textbox, chatbot], [textbox, chatbot], queue=False)
            )
    return demo
if __name__ == "__main__":
    # Module-level settings read by preprocess() and stream_bot().
    EOS_WORDS = "##"  # turn-marker prefix; generation stops once it appears
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = 512  # passed to generate() as max_new_tokens
    MODEL_NAME = "turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0"
    print("--")

    # Load the pretrained Heron model in float16, put it in eval mode and move
    # it to the selected device. The extra .half() is kept deliberately: it
    # casts any submodules the loader may have left in another dtype.
    model = VideoBlipForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        ignore_mismatched_sizes=True,
    ).half()
    model.eval()
    model.to(device)

    # The BLIP-2 processor prepares images and text; its tokenizer is replaced
    # by the nerdstash LlamaTokenizer (presumably the one matching the
    # language model - confirm).
    processor = VideoBlipProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    tokenizer = LlamaTokenizer.from_pretrained(
        "novelai/nerdstash-tokenizer-v1",
        additional_special_tokens=["▁▁"],
    )
    processor.tokenizer = tokenizer

    demo = build_demo()
    demo.queue(max_size=5).launch()