import time
from threading import Thread

import gradio as gr
import spaces
from PIL import Image

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava

# Model configuration
model_id = "TheFinAI/FinLLaVA"
device = "cuda:0"
load_8bit = False
load_4bit = False

# Load the pretrained model once at startup
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    model_id, None, "llava_llama3", load_8bit, load_4bit, device=device
)


@spaces.GPU
def bot_streaming(message, history):
    print(message)
    image = None

    # Prefer an image attached to the current message...
    if message["files"]:
        # message["files"][-1] may be a dict (with a "path" key) or a plain path string
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        # ...otherwise fall back to the most recent image found in the chat history
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    # The model cannot answer without an image
    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # Load the image from the resolved path
    image = Image.open(image)

    # The text portion of the message is the prompt
    prompt = message["text"]

    # Shared buffer: the worker thread appends generated text chunks here
    streamer = []

    # Run chat_llava in a background thread so we can yield partial output
    def generate_output():
        output = chat_llava(
            args=None,
            image_file=image,
            text=prompt,
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len,
        )
        for new_text in output:
            streamer.append(new_text)

    thread = Thread(target=generate_output)
    thread.start()

    # Stream the accumulated output; the loop only exits once the thread has
    # finished *and* the buffer is fully drained, so no trailing text is lost
    buffer = ""
    while thread.is_alive() or streamer:
        while streamer:
            buffer += streamer.pop(0)
            yield buffer
        time.sleep(0.1)


chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="LLaVA Llama-3-8B",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)