# Heron chat demo — Gradio application entry point for a Hugging Face Space.
import logging | |
import os | |
import sys | |
import time | |
from threading import Thread | |
import gradio as gr | |
import torch | |
from transformers import AutoProcessor, StoppingCriteria, TextIteratorStreamer, LlamaTokenizer | |
# Bootstrap: fetch and editable-install the Heron library at startup (the
# Spaces container does not ship it), then make the checkout importable.
# NOTE(review): os.system with a shell string is acceptable here only because
# the command is a fixed literal; subprocess.run would be more robust.
os.system(
    "git clone https://github.com/turingmotors/heron && cd heron && pip install -e ."
)
sys.path.insert(0, "./heron")

from heron.models.video_blip import VideoBlipForConditionalGeneration, VideoBlipProcessor

logger = logging.getLogger(__name__)
# This class is adapted from llava: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as any of the given keyword strings appears.

    Two checks are performed per step: an exact match of the keyword's token
    ids against the tail of the generated sequence, and a substring match of
    the keyword against the decoded text of the last few generated tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: list of stop strings (e.g. ["##"]).
            tokenizer: tokenizer used to encode keywords / decode outputs.
            input_ids: prompt tensor of shape (1, prompt_len); only tokens
                generated after the prompt are inspected.
        """
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so the suffix comparison is not
            # poisoned by a token that never appears mid-sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        # Number of prompt tokens; generation starts after this index.
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        # Decode at most the last 3 generated tokens for the substring check.
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # BUG FIX: the original wrote `if output_ids[0, -n:] == keyword_id`,
            # an elementwise comparison that raises RuntimeError whenever the
            # keyword encodes to more than one token. torch.equal performs the
            # intended whole-tensor equality test; the length guard avoids
            # comparing against a longer-than-output keyword.
            if output_ids.shape[1] >= keyword_id.shape[0] and torch.equal(
                output_ids[0, -keyword_id.shape[0]:], keyword_id
            ):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)
def preprocess(history, image):
    """Build model inputs from the chat history and the current image.

    Formats the conversation with the ``##human:``/``##gpt:`` role markers,
    tokenizes it together with the image via the module-level ``processor``,
    and moves all tensors to the module-level ``device``.

    Args:
        history: list of (user_message, bot_message) pairs; the newest pair
            has bot_message == None (set by ``add_text``).
        image: PIL image from the Gradio image box.

    Returns:
        dict of tensors ready to be passed to ``model.generate``.
    """
    text = ""
    for user_message, bot_message in history:
        text += f"##human: {user_message}\n##gpt: "
        # BUG FIX: earlier bot replies were dropped from the prompt, so the
        # model lost all multi-turn context. Completed turns now carry their
        # reply; the newest turn keeps its trailing "##gpt: " as the cue for
        # the model to answer.
        if bot_message is not None:
            text += f"{bot_message}\n"

    # do preprocessing
    inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
        truncation=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Vision inputs must match the fp16 weights of the model.
    inputs["pixel_values"] = inputs["pixel_values"].to(device, torch.float16)
    return inputs
def add_text(textbox, history):
    """Append the submitted message as a new, unanswered chat turn.

    Returns an empty string (clearing the textbox) together with the
    extended history; the reply slot is None until ``stream_bot`` fills it.
    """
    updated_history = [*history, (textbox, None)]
    return "", updated_history
def stream_bot(imagebox, history):
    """Generate the assistant reply for the newest turn, streaming tokens.

    Runs ``model.generate`` on a background thread and yields the growing
    chat history as tokens arrive, stripping the ``##`` stop marker from the
    displayed text.
    """
    # do preprocessing
    inputs = preprocess(history, imagebox)

    # BUG FIX: the original also passed do_sample / temperature /
    # no_repeat_ngram_size to TextIteratorStreamer, where extra kwargs are
    # forwarded to tokenizer.decode and silently ignored — they are
    # generation arguments, not decoding ones. The streamer only needs
    # decoding options; no_repeat_ngram_size=2 is already supplied to
    # generate() below.
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    stopping_criteria = KeywordsStoppingCriteria(
        [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
    )
    inputs.update(
        dict(
            streamer=streamer,
            max_new_tokens=max_length,
            stopping_criteria=[stopping_criteria],
            no_repeat_ngram_size=2,
            eos_token_id=[processor.tokenizer.pad_token_id],
        )
    )
    # Generation runs in a worker thread so this generator can drain the
    # streamer and yield partial histories to the UI.
    thread = Thread(target=model.generate, kwargs=inputs)
    thread.start()

    history[-1][1] = ""
    for new_text in streamer:
        history[-1][1] += new_text
        # Hide the "##" role marker that doubles as the stop word.
        history[-1][1] = history[-1][1].replace(EOS_WORDS, "")
        time.sleep(0.05)  # small delay smooths the streaming display
        yield history
def regenerate(history):
    """Drop the latest bot reply so it can be generated again in place."""
    last_user_message, _ = history[-1]
    history[-1] = (last_user_message, None)
    return history
def clear_history():
    """Reset chatbot history, textbox text, and image box to empty states."""
    empty_chat, empty_text, empty_image = [], "", None
    return empty_chat, empty_text, empty_image
def build_demo():
    """Construct and return the Gradio Blocks UI for the Heron chatbot.

    Layout: an image box with clickable examples on the left, the chatbot
    pane with a text input, Submit, Regenerate and Clear buttons on the
    right. Event handlers chain ``add_text``/``regenerate`` into
    ``stream_bot`` so replies stream into the chat.
    """
    # Created before the Blocks context so it can be placed inside the
    # layout later via textbox.render().
    textbox = gr.Textbox(
        show_label=False, placeholder="Enter text and press ENTER", visible=True, container=False
    )
    with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")

                # Clicking an example fills both the image box and the textbox.
                # (Example prompts are Japanese questions about each image.)
                gr.Examples(
                    examples=[
                        [
                            "./images/bus_kyoto.png",
                            "この道路を運転する時には何に気をつけるべきですか?",
                        ],
                        [
                            "./images/bear.png",
                            "この画像には何が写っていますか?",
                        ],
                        [
                            "./images/water_bus.png",
                            "画像には何が写っていますか?",
                        ],
                        [
                            "./images/extreme_ironing.jpg",
                            "この画像の面白い点は何ですか?",
                        ],
                        [
                            "./images/heron.png",
                            "この画像はどういう点が面白いですか?",
                        ],
                    ],
                    inputs=[imagebox, textbox],
                )

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Heron Chatbot",
                    visible=True,
                    height=550,
                    avatar_images=("./images/user_icon.png", "./images/heron.png"),
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=60):
                        submit_btn = gr.Button(value="Submit", visible=True)
                with gr.Row():
                    regenerate_btn = gr.Button(value="Regenerate", visible=True)
                    clear_btn = gr.Button(value="Clear history", visible=True)

        # Regenerate: blank the last reply, then stream a new one.
        regenerate_btn.click(regenerate, chatbot, chatbot).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )
        clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])
        # ENTER in the textbox and the Submit button share the same chain:
        # append the turn (unqueued, so the textbox clears immediately),
        # then stream the reply.
        textbox.submit(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )
        submit_btn.click(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )
    return demo
if __name__ == "__main__":
    # "##" prefixes both role markers, so it works as the stop word and is
    # stripped from the streamed reply in stream_bot.
    EOS_WORDS = "##"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Upper bound on generated tokens per reply.
    max_length = 512

    MODEL_NAME = "turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0"

    # Prepare the pretrained model in fp16 inference mode.
    # NOTE(review): ignore_mismatched_sizes tolerates shape differences in the
    # checkpoint — presumably intentional for this model; confirm upstream.
    model = VideoBlipForConditionalGeneration.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16, ignore_mismatched_sizes=True
    )
    model = model.half()
    model.eval()
    model.to(device)

    # Prepare the processor: BLIP-2 image preprocessing paired with the
    # Japanese nerdstash tokenizer this model was trained with.
    processor = VideoBlipProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    tokenizer = LlamaTokenizer.from_pretrained(
        "novelai/nerdstash-tokenizer-v1", additional_special_tokens=["▁▁"]
    )
    processor.tokenizer = tokenizer

    demo = build_demo()
    demo.queue(max_size=5).launch()