"""Gradio chat demo for the Heron VideoBLIP vision-language model.

At import time the script clones and installs the Heron repository so its model
classes can be imported. When run as a script it loads
``turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0`` and serves a chat UI in
which the user uploads an image and asks questions about it; replies are
streamed token by token.
"""
import logging
import os
import sys
import time
from threading import Thread

import gradio as gr
import torch
from transformers import AutoProcessor, StoppingCriteria, TextIteratorStreamer, LlamaTokenizer

# Clone and install Heron so that ``heron.models`` below is importable.
os.system(
    "git clone https://github.com/turingmotors/heron && cd heron && pip install -e ."
)
sys.path.insert(0, "./heron")
from heron.models.video_blip import VideoBlipForConditionalGeneration, VideoBlipProcessor  # noqa: E402

logger = logging.getLogger(__name__)

# Turn marker used in the prompt ("##human: ... ##gpt: "). Generation stops once
# the model emits it, and it is stripped from the text shown to the user.
EOS_WORDS = "##"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Maximum number of new tokens generated per reply (passed as max_new_tokens).
max_length = 512


# Adapted from LLaVA: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of ``keywords`` appears at the end of the output.

    Args:
        keywords: Strings that end generation when produced.
        tokenizer: Tokenizer used to encode the keywords and to decode the most
            recent output tokens.
        input_ids: Prompt token ids; only ``input_ids.shape[1]`` (the prompt
            length) is used.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so the ids can match mid-sequence output.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the tail of ``output_ids`` contains a keyword.

        Raises:
            ValueError: if the batch size is not 1.
        """
        if output_ids.shape[0] != 1:
            raise ValueError("Only support batch size 1 (yet)")
        # Decode at most the last 3 tokens but always at least one: a slice of
        # ``-0:`` would select the whole sequence, and the prompt itself may
        # contain a keyword (it does for the "##" turn markers).
        offset = max(1, min(output_ids.shape[1] - self.start_len, 3))
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # torch.equal works for multi-token keywords (a bare ``==`` on
            # multi-element tensors is ambiguous as a bool) and simply returns
            # False when the output is shorter than the keyword.
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)


def preprocess(history, image):
    """Build model inputs from the chat ``history`` and the uploaded ``image``.

    Every turn is rendered as ``"##human: <message>\\n##gpt: "``, so the prompt
    ends with ``"##gpt: "`` for the model to complete.

    Returns:
        dict of tensors moved to ``device``; ``pixel_values`` is cast to float16.
    """
    # NOTE(review): only user messages are included; earlier bot replies are not
    # fed back to the model. Confirm this matches the checkpoint's training
    # format before adding them (also mind ``truncation=True`` on long chats).
    text = "".join(f"##human: {turn[0]}\n##gpt: " for turn in history)
    inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
        truncation=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # The model is loaded in float16, so the image tensor must match its dtype.
    inputs["pixel_values"] = inputs["pixel_values"].to(device, torch.float16)
    return inputs


def add_text(textbox, history):
    """Append the user's message as a new, unanswered turn and clear the textbox."""
    history = history + [(textbox, None)]
    return "", history


def stream_bot(imagebox, history):
    """Generate a reply for the last turn of ``history`` and stream it.

    ``model.generate`` runs in a background thread; each decoded chunk is
    appended to ``history[-1][1]`` and the updated history is yielded to the
    Chatbot component.

    Raises:
        gr.Error: if no image is uploaded or generation fails.
    """
    if not history:
        # Nothing to answer (e.g. "Regenerate" pressed on an empty chat).
        yield history
        return
    if imagebox is None:
        raise gr.Error("Please upload an image first.")

    inputs = preprocess(history, imagebox)

    # Only decoding options belong on the streamer; generation options go to
    # ``model.generate`` below.
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [EOS_WORDS], processor.tokenizer, inputs["input_ids"]
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_length,
        stopping_criteria=[stopping_criteria],
        no_repeat_ngram_size=2,
        # NOTE(review): the pad token is used as EOS here; confirm it is what
        # the checkpoint emits at the end of a reply.
        eos_token_id=[processor.tokenizer.pad_token_id],
    )

    errors = []

    def _run_generate():
        # Runs model.generate in the worker thread; on failure, record the error
        # and end the stream so the consumer loop below does not block forever.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            logger.exception("Generation failed")
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generate)
    thread.start()

    # Entries built by add_text/regenerate are tuples; rebuild the last turn as
    # a list so the reply can be accumulated in place.
    history[-1] = [history[-1][0], ""]
    for new_text in streamer:
        history[-1][1] += new_text
        # Strip the stop keyword from the displayed text.
        history[-1][1] = history[-1][1].replace(EOS_WORDS, "")
        time.sleep(0.05)
        yield history
    thread.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]


def regenerate(history):
    """Clear the last bot reply so ``stream_bot`` can generate it again."""
    if not history:
        return history
    history[-1] = (history[-1][0], None)
    return history


def clear_history():
    """Reset the chatbot, the textbox and the image input."""
    return [], "", None


def build_demo():
    """Build the Gradio Blocks UI and wire its events to the chat handlers.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    # Created before the layout so gr.Examples can reference it; it is placed
    # in the layout later with ``textbox.render()``.
    textbox = gr.Textbox(
        show_label=False, placeholder="Enter text and press ENTER", visible=True, container=False
    )
    with gr.Blocks(title="Heron", theme=gr.themes.Base()) as demo:
        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")

                gr.Examples(
                    examples=[
                        [
                            "./images/heron.png",
                            "What is this image?",
                        ],
                    ],
                    inputs=[imagebox, textbox],
                )

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Heron Chatbot",
                    visible=True,
                    height=550,
                    avatar_images=("./images/user_icon.png", "./images/heron.png"),
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=60):
                        submit_btn = gr.Button(value="Submit", visible=True)
                with gr.Row():
                    regenerate_btn = gr.Button(value="Regenerate", visible=True)
                    clear_btn = gr.Button(value="Clear history", visible=True)

        # Each path first updates the history, then streams the model's reply.
        regenerate_btn.click(regenerate, chatbot, chatbot).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )

        clear_btn.click(clear_history, None, [chatbot, textbox, imagebox])

        textbox.submit(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )

        submit_btn.click(add_text, [textbox, chatbot], [textbox, chatbot], queue=False).then(
            stream_bot,
            [imagebox, chatbot],
            [chatbot],
        )

    return demo


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    MODEL_NAME = "turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0"

    # Prepare the pretrained model in float16 and move it to the same device
    # that preprocess() sends the inputs to.
    logger.info("Loading %s on %s", MODEL_NAME, device)
    model = VideoBlipForConditionalGeneration.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16, ignore_mismatched_sizes=True
    )
    model = model.half()
    model.eval()
    model.to(device)

    # Use the BLIP-2 processor for images, but replace its bundled tokenizer
    # with the nerdstash tokenizer (presumably the one the StableLM language
    # model expects — confirm).
    processor = VideoBlipProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    tokenizer = LlamaTokenizer.from_pretrained(
        "novelai/nerdstash-tokenizer-v1", additional_special_tokens=['▁▁']
    )
    processor.tokenizer = tokenizer

    demo = build_demo()
    demo.queue(max_size=5).launch()