# heron_chat_blip / app.py
import logging
import os
import sys
import time
from threading import Thread
import gradio as gr
import torch
from transformers import StoppingCriteria, TextIteratorStreamer, LlamaTokenizer
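# Install the Heron library at startup (a common Hugging Face Spaces pattern:
# clone the repo and pip-install it into the running environment).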
os.system(
"git clone https://github.com/turingmotors/heron && cd heron && pip install -e ."
)
sys.path.insert(0, "./heron")
from heron.models.video_blip import VideoBlipForConditionalGeneration, VideoBlipProcessor
logger = logging.getLogger(__name__)
# This class is copied from llava: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only batch size 1 is supported (yet)"  # TODO
offset = min(output_ids.shape[1] - self.start_len, 3)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
                return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
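# Build the "##human:/##gpt:" prompt from the Gradio chat history and tensorize
# it with the processor. A two-turn history produces a prompt shaped like
# (illustrative values):
#   ##human: What is this image?
#   ##gpt: A heron standing by the water.
#   ##human: Where might this be?
#   ##gpt: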
def preprocess(history, image):
text = ""
for one_history in history:
text += f"##human: {one_history[0]}\n##gpt: "
# do preprocessing
inputs = processor(
text=text,
images=image,
return_tensors="pt",
truncation=True,
)
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs["pixel_values"] = inputs["pixel_values"].to(device, torch.float16)
return inputs
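# Append the user's message to the chat history and clear the textbox.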
def add_text(textbox, history):
history = history + [(textbox, None)]
return "", history
def stream_bot(imagebox, history):
# do preprocessing
inputs = preprocess(history, imagebox)
    # stream decoded tokens as generate() produces them, skipping the prompt;
    # generation parameters belong in the generate() kwargs below, not here
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
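    # stop generating once the "##" turn delimiter shows up in the output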
stopping_criteria = KeywordsStoppingCriteria(
[EOS_WORDS], processor.tokenizer, inputs["input_ids"]
)
    inputs.update(
        dict(
            streamer=streamer,
            max_new_tokens=max_length,
            do_sample=False,
            no_repeat_ngram_size=2,
            stopping_criteria=[stopping_criteria],
            eos_token_id=[processor.tokenizer.pad_token_id],
        )
    )
thread = Thread(target=model.generate, kwargs=inputs)
thread.start()
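    # accumulate streamed chunks into the last turn, hiding the "##" stop marker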
history[-1][1] = ""
for new_text in streamer:
history[-1][1] += new_text
history[-1][1] = history[-1][1].replace(EOS_WORDS, "")
time.sleep(0.05)
yield history
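# Clear the last bot reply so the chained stream_bot call regenerates it.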
def regenerate(history):
history[-1] = (history[-1][0], None)
return history
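# Reset the chatbot, textbox, and image inputs.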
def clear_history():
return [], "", None
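# Lay out the Gradio Blocks UI: image + examples on the left, chat on the right.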
def build_demo():
textbox = gr.Textbox(
show_label=False, placeholder="Enter text and press ENTER", visible=True, container=False
)
with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
with gr.Row():
with gr.Column(scale=3):
imagebox = gr.Image(type="pil")
gr.Examples(
examples=[
[
"./images/heron.png",
"What is this image?",
],
],
inputs=[imagebox, textbox],
)
with gr.Column(scale=6):
chatbot = gr.Chatbot(
elem_id="chatbot",
label="Heron Chatbot",
visible=True,
height=550,
avatar_images=("./images/user_icon.png", "./images/heron.png"),
)
with gr.Row():
with gr.Column(scale=8):
textbox.render()
with gr.Column(scale=1, min_width=60):
submit_btn = gr.Button(value="Submit", visible=True)
with gr.Row():
regenerate_btn = gr.Button(value="Regenerate", visible=True)
clear_btn = gr.Button(value="Clear history", visible=True)
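        # wire events: each submit path appends the text, then streams the reply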
regenerate_btn.click(regenerate, chatbot, chatbot).then(
stream_bot,
[imagebox, chatbot],
[chatbot],
)
clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])
textbox.submit(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
stream_bot,
[imagebox, chatbot],
[chatbot],
)
submit_btn.click(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
stream_bot,
[imagebox, chatbot],
[chatbot],
)
return demo
if __name__ == "__main__":
EOS_WORDS = "##"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = 512
    MODEL_NAME = "turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0"
# prepare a pretrained model
model = VideoBlipForConditionalGeneration.from_pretrained(
MODEL_NAME, torch_dtype=torch.float16, ignore_mismatched_sizes=True
)
model = model.half()
model.eval()
model.to(device)
# prepare a processor
processor = VideoBlipProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
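    # Replace the default BLIP-2 tokenizer with the Japanese nerdstash tokenizer
    # used by this StableLM-based checkpoint.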
    tokenizer = LlamaTokenizer.from_pretrained(
        "novelai/nerdstash-tokenizer-v1", additional_special_tokens=["▁▁"]
    )
processor.tokenizer = tokenizer
demo = build_demo()
demo.queue(max_size=5).launch()