Spaces:
Runtime error
Runtime error
import logging | |
import os | |
import sys | |
import time | |
from threading import Thread | |
import gradio as gr | |
import torch | |
from transformers import ( | |
AutoProcessor, | |
StoppingCriteria, | |
TextIteratorStreamer, | |
CLIPImageProcessor, | |
LlamaTokenizer, | |
) | |
# Fetch and install the Heron library on first start; the imports right below need it.
if not os.path.exists("heron"):
    import subprocess

    # check=True: fail here with a clear error instead of at the heron import below.
    # (The previous shell string used ';', so pip ran even when the clone failed.)
    subprocess.run(
        ["git", "clone", "https://github.com/turingmotors/heron.git"],
        check=True,
    )
    # Install into the interpreter running this script. CUDA_HOME is set only for
    # this step, as the original `export CUDA_HOME=...; pip install` did.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "heron"],
        check=True,
        env={**os.environ, "CUDA_HOME": "/usr/local/cuda"},
    )
sys.path.insert(0, "./heron")
from heron.models.git_llm.git_japanese_stablelm_alpha import ( | |
GitJapaneseStableLMAlphaConfig, | |
GitJapaneseStableLMAlphaForCausalLM, | |
) | |
from heron.models.git_llm.git_llama import GitLlamaConfig, GitLlamaForCausalLM | |
# Module-level logger. Nothing else in this script, as shown, logs through it.
logger = logging.getLogger(__name__)
# This class is copied from llava: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
# The token-id comparison is changed to ``torch.equal``, as later LLaVA versions do.
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of ``keywords`` appears in the generated text.

    Two checks run after every generation step (batch size 1 only):

    1. The trailing token ids of the output equal a keyword's token ids.
    2. A keyword occurs as a substring of the last (up to) 3 generated tokens,
       decoded with special tokens skipped.

    Args:
        keywords: Strings that should end generation.
        tokenizer: Tokenizer used to encode ``keywords`` and decode the output.
        input_ids: Prompt ids given to ``generate``. Only ``input_ids.shape[1]``
            is used, as the offset where generated tokens begin.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token; otherwise the ids could never match mid-sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop.

        Args:
            output_ids: Ids of shape ``(1, seq_len)``. Presumably this is prompt plus
                generated tokens, since ``start_len`` is subtracted (TODO confirm).
            scores: Unused; required by the ``StoppingCriteria`` call signature.
        """
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        num_generated = output_ids.shape[1] - self.start_len
        if num_generated <= 0:
            # Nothing generated yet. This also avoids output_ids[:, -0:], which would
            # select the whole sequence, prompt included.
            return False
        offset = min(num_generated, 3)
        # Keep keyword ids on the same device as the output (a no-op after the first call).
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            tail = output_ids[0, -keyword_id.shape[0]:]
            # `tail == keyword_id` is element-wise. Using it in `if` raises for multi-token
            # keywords and fails on size mismatch; torch.equal returns a plain bool.
            if torch.equal(tail, keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)
def preprocess(history, image):
    r"""Render the chat history as a Heron prompt and encode it with ``image``.

    Each turn becomes ``"##human: {question}\n##gpt: "``. When that turn already
    has an answer, the answer and a newline follow. The latest turn normally has
    no answer yet, so the prompt ends with ``"##gpt: "`` for the model to continue.

    Relies on the module-level ``processor`` and ``device`` created under
    ``__main__``.

    Args:
        history: Chat turns as ``(question, answer)`` pairs (tuples or lists).
            ``answer`` is None or empty for a turn still awaiting a reply.
        image: The image passed to the processor (from ``gr.Image(type="pil")``).

    Returns:
        Dict of processor outputs, each value moved to ``device``.
    """
    parts = []
    for one_history in history:
        question, answer = one_history[0], one_history[1]
        parts.append(f"##human: {question}\n##gpt: ")
        # Include earlier answers. Without them, multi-turn prompts read
        # "##gpt: ##human: ..." and the model never sees its own replies.
        # NOTE(review): the "\n" separator is assumed to match Heron's chat format.
        if answer:
            parts.append(f"{answer}\n")
    text = "".join(parts)
    # do preprocessing
    inputs = processor(
        text,
        image,
        return_tensors="pt",
        truncation=True,
    )
    return {k: v.to(device) for k, v in inputs.items()}
def add_text(textbox, history):
    """Queue the user's message as a new chat turn awaiting a reply.

    Args:
        textbox: Text typed by the user. Only the first 512 characters are kept.
        history: Existing chat turns. Not modified; a new list is returned.

    Returns:
        ``("", new_history)``. The empty string clears the input box, and
        ``new_history`` ends with ``(message, None)``.
    """
    max_chars = 512  # hard cap on a single message
    pending_turn = (textbox[:max_chars], None)
    return "", history + [pending_turn]
# Markdown header shown at the top of the demo page (rendered by gr.Markdown in build_demo).
title_markdown = """
# Heron Chat Demo
- Model: [heron-chat-git-Llama-2-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-Llama-2-7b-v0)
- Codes: [Heron](https://github.com/turingmotors/heron)
"""
def stream_bot(imagebox, history):
    """Generate the reply to the latest turn and stream it into ``history``.

    ``model.generate`` runs on a background thread and writes text into a
    ``TextIteratorStreamer``. This generator consumes that text and yields the
    updated history after every chunk, so Gradio can refresh the chatbot.

    Relies on the module-level ``model``, ``processor``, ``EOS_WORDS`` and
    ``max_length`` created under ``__main__``.

    Args:
        imagebox: Image from the UI, forwarded to ``preprocess``.
        history: Chat turns as ``(question, answer)`` pairs. The last one is answered.

    Yields:
        ``history``, with the last turn's answer holding the text generated so far.
    """
    if not history:
        # e.g. "Regenerate" clicked on an empty chat: there is nothing to answer.
        yield history
        return

    # do preprocessing
    inputs = preprocess(history, imagebox)
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
    )
    inputs.update(
        dict(
            streamer=streamer,
            max_new_tokens=max_length,
            stopping_criteria=[stopping_criteria],
            do_sample=True,
            temperature=0.4,
            no_repeat_ngram_size=2,
        )
    )
    thread = Thread(target=model.generate, kwargs=inputs)
    thread.start()

    # Turns may arrive as tuples (add_text and regenerate build tuples), which cannot
    # be assigned into. Replace the last turn with a mutable list.
    history[-1] = [history[-1][0], ""]
    generated = ""
    for new_text in streamer:
        generated += new_text
        # Show only the text before the stop marker, and hide a trailing '#' that may
        # be the first half of it. This works on a copy, so a '#' that turns out to be
        # ordinary text reappears with the next chunk instead of being deleted.
        history[-1][1] = generated.split(EOS_WORDS, 1)[0].rstrip("#")
        time.sleep(0.05)  # throttle UI updates
        yield history
    thread.join()
def regenerate(history):
    """Drop the bot's reply to the latest turn so it can be generated again.

    Args:
        history: Chat turns as ``(question, answer)`` pairs. Mutated in place.

    Returns:
        The same ``history`` object, with the last answer set to None. An empty
        history is returned unchanged instead of raising ``IndexError``.
    """
    if history:
        history[-1] = (history[-1][0], None)
    return history
def clear_history():
    """Reset the UI to an empty chat, an empty text box, and no image.

    The tuple order matches the outputs wired in build_demo:
    (chatbot, textbox, imagebox).
    """
    empty_chat = []
    empty_text = ""
    no_image = None
    return empty_chat, empty_text, no_image
def build_demo():
    """Assemble the Gradio Blocks UI for the Heron chat demo and wire its events.

    Layout:
        Left column: the image input with clickable examples.
        Right column: the chatbot, the text box with a Submit button, and the
        Regenerate / Clear buttons.

    Returns:
        The configured ``gr.Blocks`` instance. It is not yet queued or launched.
    """
    # Created before the layout so gr.Examples can target it. It is placed into the
    # layout later with textbox.render().
    textbox = gr.Textbox(
        show_label=False,
        placeholder="Enter text and press ENTER",
        visible=True,
        container=False,
    )

    # (image path, prompt) pairs offered as clickable examples.
    # Japanese versions of the same prompts, in the same order:
    #   この道路を運転する時には何に気をつけるべきですか？
    #   この画像には何が写っていますか？
    #   画像には何が写っていますか？
    #   この画像の面白い点は何ですか？
    #   この画像はどういう点が面白いですか？
    example_rows = [
        ["./images/bus_kyoto.png", "What should you be careful of when driving on this road?"],
        ["./images/bear.png", "What is shown in this image?"],
        ["./images/water_bus.png", "What is depicted in the picture?"],
        ["./images/extreme_ironing.jpg", "What is the unusual aspect of this image?"],
        ["./images/heron.png", "What is intriguing about this picture?"],
    ]

    with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
        gr.Markdown(title_markdown)
        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")
                gr.Examples(examples=example_rows, inputs=[imagebox, textbox])
            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Heron Chatbot",
                    visible=True,
                    height=550,
                    avatar_images=("./images/user_icon.png", "./images/heron.png"),
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=60):
                        submit_btn = gr.Button(value="Submit", visible=True)
                with gr.Row():
                    regenerate_btn = gr.Button(value="Regenerate", visible=True)
                    clear_btn = gr.Button(value="Clear history", visible=True)

        bot_inputs = [imagebox, chatbot]
        bot_outputs = [chatbot]

        regenerate_btn.click(regenerate, chatbot, chatbot).then(
            stream_bot, bot_inputs, bot_outputs
        )
        clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])
        # Pressing ENTER and clicking Submit do the same thing: record the message
        # (unqueued, so the UI updates at once), then stream the reply.
        for send_event in (textbox.submit, submit_btn.click):
            send_event(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
                stream_bot, bot_inputs, bot_outputs
            )
    return demo
if __name__ == "__main__":
    # These names are module globals read by preprocess / stream_bot.
    # Prompt turns start with "##" ("##human:", "##gpt:"); stream_bot stops
    # generation once the model emits this marker.
    EOS_WORDS = "##"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = 512  # passed as max_new_tokens in stream_bot
    vision_model_name = "openai/clip-vit-large-patch14-336"
    MODEL_NAME = "turing-motors/heron-chat-git-Llama-2-7b-v0"
    PROCESSOR_PATH = "turing-motors/heron-chat-git-Llama-2-7b-v0"

    # prepare a pretrained model
    git_config = GitLlamaConfig.from_pretrained(MODEL_NAME)
    git_config.set_vision_configs(
        num_image_with_embedding=1, vision_model_name=vision_model_name
    )
    model = GitLlamaForCausalLM.from_pretrained(
        MODEL_NAME, config=git_config, torch_dtype=torch.float16
    )
    model.eval()
    model.to(device)

    # prepare a processor
    processor = AutoProcessor.from_pretrained(PROCESSOR_PATH)

    demo = build_demo()
    # Gradio 4 rejects queue(concurrency_count=...) and uses default_concurrency_limit
    # instead. Pick the spelling that the installed version accepts.
    if int(gr.__version__.split(".")[0]) >= 4:
        demo.queue(default_concurrency_limit=1, max_size=5, api_open=False)
    else:
        demo.queue(concurrency_count=1, max_size=5, api_open=False)
    demo.launch(share=False)