# FinLLaVA / app.py
# Standard library
import time
from threading import Thread

# Third-party
import gradio as gr
import spaces
import torch  # not used directly in this file
from PIL import Image

# Local FinLLaVA helpers
from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava

# Model configuration
model_id = "TheFinAI/FinLLaVA"
device = "cuda:0"
load_8bit = False  # set True for 8-bit quantized loading
load_4bit = False  # set True for 4-bit quantized loading

# Load the pretrained tokenizer, model, and image processor once at startup
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    model_id,
    None,             # no separate model base
    'llava_llama3',   # model name key used by the builder
    load_8bit,
    load_4bit,
    device=device,
)
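
# Hedged note (an assumption based on the upstream LLaVA loader, which maps
# these flags onto bitsandbytes quantization): flipping load_8bit or load_4bit
# to True above should shrink the GPU memory footprint at some cost in output
# quality, e.g.
#
#   load_4bit = True  # set before the load_pretrained_model call above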

@spaces.GPU
def bot_streaming(message, history):
    """Stream a FinLLaVA response for a Gradio multimodal message."""
    print(message)
    image = None

    # Gradio's MultimodalTextbox delivers attachments in message["files"];
    # each entry may be a dict with a "path" key or a plain file-path string.
    if message["files"]:
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        # No image in the current turn: fall back to the most recent image
        # in the chat history (image turns are stored as tuples).
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    # LLaVA cannot answer without an image, so fail with a visible error
    if image is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # Load the image from disk
    image = Image.open(image)

    # The user's text becomes the prompt
    prompt = message['text']

    # Shared buffer between the generation thread and this generator. A plain
    # list is workable under the GIL; see the queue-based sketch below for a
    # more explicit hand-off.
    streamer = []

    # Produce output on a worker thread, appending chunks to the shared list
    def generate_output():
        output = chat_llava(
            args=None,
            image_file=image,
            text=prompt,
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len,
        )
        for new_text in output:
            streamer.append(new_text)

    # Run generation in the background so partial output can be streamed
    thread = Thread(target=generate_output)
    thread.start()

    # Drain the shared list while the worker is alive. The loop can only exit
    # once the thread has finished AND the list is empty, so no trailing
    # chunks are lost.
    buffer = ""
    while thread.is_alive() or streamer:
        while streamer:
            buffer += streamer.pop(0)
            yield buffer
        time.sleep(0.1)  # brief pause between polls to avoid busy-waiting
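

# Hedged alternative (a sketch, not used by the app above): the same
# producer/consumer hand-off built on queue.Queue with a None sentinel.
# The blocking get() removes the need for time.sleep polling and does not
# rely on list operations being atomic under the GIL. The function name is
# illustrative; like bot_streaming, it assumes chat_llava yields text chunks.
import queue


def _stream_via_queue(image, prompt):
    chunks = queue.Queue()

    def produce():
        for new_text in chat_llava(
            args=None,
            image_file=image,
            text=prompt,
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len,
        ):
            chunks.put(new_text)
        chunks.put(None)  # sentinel: generation finished

    Thread(target=produce, daemon=True).start()

    buffer = ""
    while True:
        new_text = chunks.get()  # blocks until the next chunk arrives
        if new_text is None:
            break
        buffer += new_text
        yield buffer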

# UI components
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="LLaVA Llama-3-8B",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
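
# Hedged smoke test (an assumption, not part of the original app): with the
# demo.launch(...) call above commented out, the streaming handler can be
# exercised directly; "./bee.jpg" is one of the example paths listed above.
#
#   for partial in bot_streaming(
#           {"text": "Describe this image.", "files": ["./bee.jpg"]},
#           history=[]):
#       print(partial)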