from threading import Thread

from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
import gradio as gr
import torch
from PIL import Image
from transformers import TextStreamer
import argparse
import spaces
import os
import time

root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Load the model
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device)


@spaces.GPU
def bot_streaming(message, history):
    print(message)
    image_path = None

    # Check if there's an image in the current message
    if message["files"]:
        # message["files"][-1] could be a dictionary or a string
        if isinstance(message["files"][-1], dict):
            image_path = message["files"][-1]["path"]
        else:
            image_path = message["files"][-1]
    else:
        # If no image in the current message, look in the history for the last image path
        for hist in history:
            if isinstance(hist[0], tuple):
                image_path = hist[0][0]

    # Error handling if no image path is found
    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # If image_path is a string, there is no need to load it into a PIL image;
    # the path is passed directly to chat_llava below.
    print(f"\033[91m{image_path}, {type(image_path)}\033[0m")

    # Generate the prompt for the model
    prompt = message['text']
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Set up the generation arguments, including the streamer
    generation_kwargs = dict(
        args=args,
        image_file=image_path,
        text=prompt,
        tokenizer=tokenizer,
        model=llava_model,
        streamer=streamer,
        image_processor=image_processor,  # TODO: input model name or path
        context_len=context_len)

    # Define the function that forwards the keyword arguments to `chat_llava`
    def generate_output(**kwargs):
        chat_llava(**kwargs)

    # Start the generation in a separate thread
    thread = Thread(target=generate_output, kwargs=generation_kwargs)
    thread.start()

    # Initialize a buffer to accumulate the generated text
    buffer = ""

    # Allow the generation to start
    time.sleep(0.5)

    # Iterate over the streamer to handle the incoming text in chunks
    for new_text in streamer:
        # Look for the end-of-text token and remove it
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]

        # Add the new text to the buffer
        buffer += new_text

        # Remove the prompt from the generated text (if necessary)
        generated_text_without_prompt = buffer[len(prompt):]

        # Simulate processing time (optional)
        time.sleep(0.06)

        # Yield the current generated text for further processing or display
        yield generated_text_without_prompt


chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True,
                                  file_types=["image"],
                                  placeholder="Enter message or upload file...",
                                  show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
                  {"text": "How to make this pastry?", "files": ["./baklava.png"]},
                  {"text": "What is this?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]}],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)