import logging
import os
import sys
import time
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoProcessor,
    CLIPImageProcessor,
    LlamaTokenizer,
    StoppingCriteria,
    TextIteratorStreamer,
)

# Clone and install Heron on first launch so its modules are importable.
if not os.path.exists("heron"):
    os.system(
        "git clone https://github.com/turingmotors/heron.git "
        "&& export CUDA_HOME=/usr/local/cuda; pip install -e heron"
    )

sys.path.insert(0, "./heron")
from heron.models.git_llm.git_japanese_stablelm_alpha import (
    GitJapaneseStableLMAlphaConfig,
    GitJapaneseStableLMAlphaForCausalLM,
)
from heron.models.git_llm.git_llama import GitLlamaConfig, GitLlamaForCausalLM

logger = logging.getLogger(__name__)


# This class is copied from llava: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py#L51-L74
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stops generation once one of the given keywords appears in the output."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so the keyword can match mid-sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        # Token-level match at the tail of the output.
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        # String-level fallback on the last few decoded tokens.
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False


def preprocess(history, image):
    # Build the prompt from the user turns; earlier bot replies are not replayed.
    text = ""
    for one_history in history:
        text += f"##human: {one_history[0]}\n##gpt: "

    inputs = processor(
        text,
        image,
        return_tensors="pt",
        truncation=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    return inputs


def add_text(textbox, history):
    # Hard limit on the length of a single user message.
    if len(textbox) > 512:
        textbox = textbox[:512]

    history = history + [(textbox, None)]
    return "", history


title_markdown = """
# Heron Chat Demo
- Model: [heron-chat-git-Llama-2-7b-v0](https://huggingface.co/turing-motors/heron-chat-git-Llama-2-7b-v0)
- Code: [Heron](https://github.com/turingmotors/heron)
"""


def stream_bot(imagebox, history):
    inputs = preprocess(history, imagebox)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
    )

    inputs.update(
        dict(
            streamer=streamer,
            max_new_tokens=max_length,
            stopping_criteria=[stopping_criteria],
            do_sample=True,
            temperature=0.4,
            no_repeat_ngram_size=2,
        )
    )
    # Generate on a background thread so tokens can stream to the UI as they arrive.
    thread = Thread(target=model.generate, kwargs=inputs)
    thread.start()

    history[-1][1] = ""
    for new_text in streamer:
        history[-1][1] += new_text
        # Trim trailing "#" so a partially emitted EOS marker ("##") never shows.
        while history[-1][1].endswith("#"):
            history[-1][1] = history[-1][1][:-1]
        time.sleep(0.05)
        yield history


def regenerate(history):
    # Clear the last bot reply so stream_bot regenerates it.
    history[-1] = (history[-1][0], None)
    return history


def clear_history():
    return [], "", None


def build_demo():
    textbox = gr.Textbox(
        show_label=False, placeholder="Enter text and press ENTER", visible=True, container=False
    )

    with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")

                gr.Examples(
                    examples=[
                        [
                            "./images/bus_kyoto.png",
                            "What should you be careful of when driving on this road?",
                        ],
                        [
                            "./images/bear.png",
                            "What is shown in this image?",
                        ],
                        [
                            "./images/water_bus.png",
                            "What is depicted in the picture?",
                        ],
                        [
                            "./images/extreme_ironing.jpg",
                            "What is the unusual aspect of this image?",
                        ],
                        [
                            "./images/heron.png",
                            "What is intriguing about this picture?",
                        ],
                    ],
                    inputs=[imagebox, textbox],
                )

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Heron Chatbot",
                    visible=True,
                    height=550,
                    avatar_images=("./images/user_icon.png", "./images/heron.png"),
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=60):
                        submit_btn = gr.Button(value="Submit", visible=True)
                with gr.Row():
                    regenerate_btn = gr.Button(value="Regenerate", visible=True)
                    clear_btn = gr.Button(value="Clear history", visible=True)

        regenerate_btn.click(regenerate, chatbot, chatbot).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )
        clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])
        textbox.submit(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )
        submit_btn.click(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )

    return demo


if __name__ == "__main__":
    EOS_WORDS = "##"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = 512

    vision_model_name = "openai/clip-vit-large-patch14-336"
    MODEL_NAME = "turing-motors/heron-chat-git-Llama-2-7b-v0"
    PROCESSOR_PATH = "turing-motors/heron-chat-git-Llama-2-7b-v0"

    # Prepare the pretrained model and attach the CLIP vision encoder config.
    git_config = GitLlamaConfig.from_pretrained(MODEL_NAME)
    git_config.set_vision_configs(
        num_image_with_embedding=1, vision_model_name=vision_model_name
    )
    model = GitLlamaForCausalLM.from_pretrained(
        MODEL_NAME, config=git_config, torch_dtype=torch.float16
    )

    model.eval()
    model.to(device)

    # Prepare the processor.
    processor = AutoProcessor.from_pretrained(PROCESSOR_PATH)

    demo = build_demo()
    demo.queue(concurrency_count=1, max_size=5, api_open=False).launch(share=False)