Spaces:
Running
on
Zero
Running
on
Zero
from threading import Thread | |
from llava_llama3.serve.cli import chat_llava | |
from llava_llama3.model.builder import load_pretrained_model | |
import gradio as gr | |
import torch | |
from PIL import Image | |
import argparse | |
import spaces | |
import os | |
import time | |
# Directory that contains this script; echoed once at startup.
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)


def _build_arg_parser():
    """Return the command-line parser for model, device and decoding options."""
    cli = argparse.ArgumentParser()
    # Checkpoint to load and the device to place it on.
    cli.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
    cli.add_argument("--device", type=str, default="cuda:0")
    # Conversation/decoding settings. The parsed namespace is handed to
    # chat_llava as `args`; presumably it reads these fields (not visible here).
    cli.add_argument("--conv-mode", type=str, default="llama_3")
    cli.add_argument("--temperature", type=float, default=0)
    cli.add_argument("--max-new-tokens", type=int, default=512)
    # Optional quantized loading flags, forwarded to load_pretrained_model.
    cli.add_argument("--load-8bit", action="store_true")
    cli.add_argument("--load-4bit", action="store_true")
    return cli


parser = _build_arg_parser()
args = parser.parse_args()

# Load everything once at import time so every chat request reuses the same
# tokenizer / model / image processor.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,  # NOTE(review): presumably "model base" (none) — confirm against builder
    "llava_llama3",  # NOTE(review): presumably the model name/family selector
    args.load_8bit,
    args.load_4bit,
    device=args.device,
)
def _latest_image(message, history):
    """Return the image to use for this turn, or ``None`` if there is none.

    The last file attached to the current message wins. Otherwise the chat
    history is scanned and the most recent uploaded image (a tuple entry
    whose first element is the path) is returned.
    """
    files = message["files"]
    if files:
        # An entry may be a dict carrying a "path" key or a plain path string.
        last = files[-1]
        return last["path"] if isinstance(last, dict) else last

    image = None
    for turn in history:
        if isinstance(turn, dict):
            # "messages"-style history entry ({"role": ..., "content": ...}).
            # NOTE(review): assumes uploaded files appear as tuple content on
            # user turns — confirm against the installed Gradio version.
            content = turn.get("content") if turn.get("role") == "user" else None
        else:
            # Pair-style history entry: [user_message, bot_message].
            content = turn[0]
        # Keep overwriting so the most recent image in the history wins.
        if isinstance(content, tuple) and content:
            image = content[0]
    return image


# NOTE(review): `spaces` is imported by this file but nothing here is decorated
# with @spaces.GPU. If this Space runs on ZeroGPU hardware and chat_llava is not
# decorated internally, this handler likely needs @spaces.GPU — confirm.
def bot_streaming(message, history):
    """Stream a FinLLaVA answer for one multimodal chat turn.

    Args:
        message: MultimodalTextbox value with "text" and "files" keys.
        history: Previous turns as passed in by ``gr.ChatInterface``.

    Yields:
        The reply accumulated so far, once per chunk produced by ``chat_llava``.

    Raises:
        gr.Error: if no image is available in the message or the history, or
            if generation fails in the worker thread.
    """
    import queue  # only this handler needs it

    print(message)

    image = _latest_image(message, history)
    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")
    if isinstance(image, str):
        # Decode from disk and close the file handle immediately rather than
        # leaving it open until garbage collection.
        with Image.open(image) as img:
            image = img.convert('RGB')

    prompt = message['text']

    # Thread-safe hand-off from the worker; `done` marks the end of the stream.
    chunks = queue.Queue()
    done = object()
    errors = []

    def generate_output():
        try:
            output = chat_llava(
                args=args,
                image_file=image,
                text=prompt,
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
            )
            for new_text in output:
                chunks.put(new_text)
        except Exception as exc:  # surfaced to the UI after streaming stops
            errors.append(exc)
        finally:
            # Always post the sentinel so the consumer loop can never hang.
            chunks.put(done)

    # Daemon so an abandoned generation (e.g. user pressed Stop) does not
    # block interpreter shutdown.
    thread = Thread(target=generate_output, daemon=True)
    thread.start()

    buffer = ""
    while True:
        new_text = chunks.get()  # blocks until a chunk or the sentinel arrives
        if new_text is done:
            break
        buffer += new_text
        yield buffer
    thread.join()

    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]
# Components are created up front and handed to ChatInterface, which renders
# them inside the Blocks layout below.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

# Sample prompts shown under the chat; file paths are relative to the CWD.
_EXAMPLES = [
    {"text": "What is on the flower?", "files": ["./bee.jpg"]},
    {"text": "How to make this pastry?", "files": ["./baklava.png"]},
]

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=_EXAMPLES,
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

# Queue requests, and neither open nor advertise the programmatic API.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)