# FinLLaVA / app.py
# Last commit: 7dc477a "Update app.py" by TobyYang7 (verified); file size 3.67 kB.
from threading import Thread
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
import gradio as gr
import torch
from PIL import Image
import argparse
import spaces
import os
import time
# Absolute path of the directory that contains this script; echoed at start-up.
root_path = os.path.split(os.path.abspath(__file__))[0]
print(root_path)
# Command-line configuration for the app. Every option has a default, so the
# script can be started with no flags at all.
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
# Default is 0.0 (not the int 0) so args.temperature is always a float,
# matching what type=float produces when the flag is given explicitly.
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
# parse_known_args keeps start-up from aborting when a launcher injects extra
# arguments. Unrecognized ones are reported rather than silently dropped.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    print(f"Ignoring unrecognized command-line arguments: {unknown_args}")
# Load the model components once at start-up; bot_streaming passes these
# module-level objects to chat_llava on every request.
_loader_positional = (
    args.model_path,
    None,  # NOTE(review): presumably the optional base-model path — confirm against the builder
    'llava_llama3',  # NOTE(review): presumably the model name the builder dispatches on — confirm
    args.load_8bit,
    args.load_4bit,
)
(tokenizer, llava_model, image_processor,
 context_len) = load_pretrained_model(*_loader_positional, device=args.device)
del _loader_positional
def _find_image(message, history):
    """Return the image reference to use for this turn, or None if there is none.

    The last file attached to the current message wins. Otherwise the history
    is scanned, and the last entry whose user part is a tuple supplies the
    image (its first element).

    Args:
        message: multimodal message dict with a "files" key; each file is
            either a path string or a dict with a "path" key.
        history: previous chat turns, indexed as hist[0] for the user part
            (tuple-style history assumed — TODO confirm if the Chatbot
            format changes).
    """
    files = message["files"]
    if files:
        # message["files"][-1] could be a dictionary or a string
        last = files[-1]
        return last["path"] if isinstance(last, dict) else last
    image = None
    for hist in history:
        # Keep overwriting so the most recent image in the history wins.
        if isinstance(hist[0], tuple):
            image = hist[0][0]
    return image


@spaces.GPU
def bot_streaming(message, history):
    """Answer a multimodal chat turn, streaming the reply as it is generated.

    chat_llava runs in a worker thread. Every piece it produces is appended to
    the reply, and the cumulative text is yielded after each piece.

    Args:
        message: dict with "text" (the prompt) and "files" (uploaded files).
        history: previous chat turns, searched for an image when the current
            message has none.

    Yields:
        str: the reply accumulated so far.

    Raises:
        gr.Error: if no image is available, or if generation fails in the
            worker thread (previously such failures were silently lost).
    """
    import queue  # local imports: only this handler needs them
    import traceback

    print(message)
    image = _find_image(message, history)
    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")
    # Load the image if it's a path, otherwise use the existing PIL image.
    # The context manager closes the file once the RGB copy has been made.
    if isinstance(image, str):
        with Image.open(image) as img:
            image = img.convert('RGB')
    prompt = message['text']

    chunks = queue.Queue()
    end_of_stream = object()  # sentinel: the worker has finished (success or failure)
    errors = []  # exception raised in the worker, handed to the consumer

    def generate_output():
        """Run chat_llava and push each output piece onto the queue."""
        try:
            output = chat_llava(
                args=args,
                image_file=image,
                text=prompt,
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
            )
            for new_text in output:
                chunks.put(new_text)
        except Exception as exc:
            traceback.print_exc()  # keep the server-side traceback visible
            errors.append(exc)
        finally:
            # Always unblock the consumer, even when generation fails.
            chunks.put(end_of_stream)

    thread = Thread(target=generate_output)
    thread.start()

    # Block on the queue instead of polling with sleep. Each piece is shown as
    # soon as it arrives, and the loop ends exactly when the worker finishes.
    buffer = ""
    while True:
        new_text = chunks.get()
        if new_text is end_of_stream:
            break
        buffer += new_text
        yield buffer
    thread.join()

    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]
# The chat widgets are created up front and handed to ChatInterface below.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        # Example images are resolved against the script's own directory
        # (root_path), not the current working directory, so they are found
        # regardless of where the app is launched from.
        examples=[
            {"text": "What is on the flower?", "files": [os.path.join(root_path, "bee.jpg")]},
            {"text": "How to make this pastry?", "files": [os.path.join(root_path, "baklava.png")]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

# Enable the request queue with its API closed, then serve the UI locally
# without the API page or a public share link.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)