Spaces:
Running
Running
File size: 7,388 Bytes
fa44dd0 7020a09 fa44dd0 ccd8557 fa44dd0 ccd8557 fa44dd0 efbb3ca fa44dd0 ddf556f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import logging
import os
import sys
import time
from threading import Thread
import gradio as gr
import torch
from transformers import AutoProcessor, StoppingCriteria, TextIteratorStreamer, LlamaTokenizer
# heron is installed from its GitHub repository: clone it and install it in
# editable mode. Skip this when a checkout already exists, because `git clone`
# into an existing non-empty directory fails (which also skipped the install).
if not os.path.isdir("heron"):
    _heron_setup_status = os.system(
        "git clone https://github.com/turingmotors/heron && cd heron && pip install -e ."
    )
    if _heron_setup_status != 0:
        # Not fatal here: the checkout may still be importable through the
        # sys.path entry below; if it is not, the heron import will fail.
        print(
            f"warning: heron setup exited with status {_heron_setup_status}",
            file=sys.stderr,
        )
# Make the checked-out package importable even if the editable install failed.
sys.path.insert(0, "./heron")
from heron.models.video_blip import VideoBlipForConditionalGeneration, VideoBlipProcessor
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
# Adapted from LLaVA: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
# (token-sequence comparison uses torch.equal so multi-token keywords work).
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of ``keywords`` appears at the end of the output.

    After each generation step two checks run: an exact match of each
    keyword's token ids against the tail of the sequence, and a substring
    search in the decoded text of the last (up to 3) newly generated tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: Strings whose appearance should end generation.
            tokenizer: Tokenizer used to encode the keywords and decode output.
            input_ids: Prompt token ids; only ``input_ids.shape[1]`` is kept so
                that the prompt itself is never searched for keywords.
        """
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token if the tokenizer prepended one.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop (batch size 1 only)."""
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        num_new_tokens = output_ids.shape[1] - self.start_len
        if num_new_tokens <= 0:
            # Nothing generated yet; slicing with ``-0:`` would select the whole
            # prompt (which may contain a keyword) and stop immediately.
            return False
        offset = min(num_new_tokens, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # torch.equal returns False on shape mismatch and avoids calling
            # bool() on a multi-element comparison tensor (which raises).
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)
def preprocess(history, image):
text = ""
for one_history in history:
text += f"##human: {one_history[0]}\n##gpt: "
# do preprocessing
inputs = processor(
text=text,
images=image,
return_tensors="pt",
truncation=True,
)
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs["pixel_values"] = inputs["pixel_values"].to(device, torch.float16)
return inputs
def add_text(textbox, history):
    """Append the submitted message as a new, unanswered turn.

    Args:
        textbox: The text the user submitted.
        history: Current chat history (not mutated).

    Returns:
        ``("", new_history)``: an empty string that clears the textbox, and a
        new history ending with ``(textbox, None)``.
    """
    updated_history = history + [(textbox, None)]
    cleared_input = ""
    return cleared_input, updated_history
def stream_bot(imagebox, history):
    """Generate the reply to the last user turn, yielding partial results.

    ``model.generate`` runs on a background thread feeding a
    ``TextIteratorStreamer``; each chunk of decoded text is appended to the
    last turn's bot message and the history is yielded so the UI can update
    incrementally.

    Args:
        imagebox: Image from the Gradio image component, passed to ``preprocess``.
        history: Chat history as ``[user_message, bot_message]`` pairs; the
            last pair's bot message is (re)generated.

    Yields:
        The history list with the last bot message updated.
    """
    if not history:
        # Nothing to answer (e.g. "Regenerate" pressed on an empty chat).
        yield history
        return

    inputs = preprocess(history, imagebox)

    # Only decoding options belong on the streamer: its extra kwargs are
    # forwarded to tokenizer.decode, so generation parameters go to generate().
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
    )
    inputs.update(
        streamer=streamer,
        max_new_tokens=max_length,
        stopping_criteria=[stopping_criteria],
        no_repeat_ngram_size=2,
        # NOTE(review): the pad token id is used as EOS; presumably the model
        # ends answers with it — confirm against the heron training setup.
        eos_token_id=[processor.tokenizer.pad_token_id],
    )
    thread = Thread(target=model.generate, kwargs=inputs)
    thread.start()

    # Use a mutable pair so the reply can be filled in even if the last
    # entry arrived as a tuple (add_text/regenerate produce tuples).
    history[-1] = [history[-1][0], ""]
    for new_text in streamer:
        history[-1][1] += new_text
        # Hide the stop marker from the displayed answer.
        history[-1][1] = history[-1][1].replace(EOS_WORDS, "")
        time.sleep(0.05)
        yield history
    thread.join()
def regenerate(history):
    """Clear the last bot reply so it can be generated again.

    Args:
        history: Chat history as a list of ``(user_message, bot_message)`` pairs.

    Returns:
        The same list, with its final pair replaced by ``(user_message, None)``.
        An empty history is returned unchanged instead of raising IndexError.
    """
    if not history:
        return history
    history[-1] = (history[-1][0], None)
    return history
def clear_history():
    """Return reset values for the chatbot, textbox and image components."""
    empty_chat = []
    empty_text = ""
    no_image = None
    return empty_chat, empty_text, no_image
def build_demo():
    """Assemble the Gradio Blocks UI for the Heron chat demo.

    Left column: image input plus example (image, question) pairs. Right
    column: chat window, text input with Submit button, and Regenerate /
    Clear history buttons.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # Created outside the Blocks context so gr.Examples can reference it;
    # it is placed into the layout later via render().
    textbox = gr.Textbox(
        show_label=False,
        placeholder="Enter text and press ENTER",
        visible=True,
        container=False,
    )
    example_pairs = [
        ["./images/bus_kyoto.png", "この道路を運転する時には何に気をつけるべきですか?"],
        ["./images/bear.png", "この画像には何が写っていますか?"],
        ["./images/water_bus.png", "画像には何が写っていますか?"],
        ["./images/extreme_ironing.jpg", "この画像の面白い点は何ですか?"],
        ["./images/heron.png", "この画像はどういう点が面白いですか?"],
    ]

    with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")
                gr.Examples(examples=example_pairs, inputs=[imagebox, textbox])

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Heron Chatbot",
                    visible=True,
                    height=550,
                    avatar_images=("./images/user_icon.png", "./images/heron.png"),
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=60):
                        submit_btn = gr.Button(value="Submit", visible=True)
                with gr.Row():
                    regenerate_btn = gr.Button(value="Regenerate", visible=True)
                    clear_btn = gr.Button(value="Clear history", visible=True)

        # Regenerate: blank the last reply, then stream a fresh one.
        regenerate_btn.click(regenerate, chatbot, chatbot).then(
            stream_bot, [imagebox, chatbot], [chatbot]
        )
        clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])

        # Enter in the textbox and the Submit button share one pipeline:
        # append the user turn (unqueued), then stream the bot's answer.
        for trigger in (textbox.submit, submit_btn.click):
            trigger(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
                stream_bot, [imagebox, chatbot], [chatbot]
            )

    return demo
if __name__ == "__main__":
    # These names are read as module globals by preprocess() and stream_bot().
    # Marker that begins the next "##human:" turn; generation stops on it.
    EOS_WORDS = "##"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = 512  # passed to generate() as max_new_tokens
    MODEL_NAME = "turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0"

    logging.basicConfig(level=logging.INFO)
    logger.info("Loading model %s on %s", MODEL_NAME, device)

    # Prepare the pretrained model in float16.
    # NOTE(review): float16 is used even on CPU-only hosts, where many
    # half-precision ops are unsupported — generation may fail there.
    model = VideoBlipForConditionalGeneration.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16, ignore_mismatched_sizes=True
    )
    model = model.half()
    model.eval()
    model.to(device)

    # The processor configuration comes from BLIP-2, but its tokenizer is
    # replaced with novelai/nerdstash-tokenizer-v1 (presumably the tokenizer
    # the language model was trained with — confirm).
    processor = VideoBlipProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    tokenizer = LlamaTokenizer.from_pretrained(
        "novelai/nerdstash-tokenizer-v1", additional_special_tokens=['▁▁']
    )
    processor.tokenizer = tokenizer

    demo = build_demo()
    demo.queue(max_size=5).launch()
|