# FinLLaVA / app.py
import argparse
import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import TextStreamer

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.serve.cli import chat_llava
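
# Directory containing this script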
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)
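
# Command-line options for the model checkpoint, device, precision, and generation settings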
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()
# load model
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
args.model_path,
None,
'llava_llama3',
args.load_8bit,
args.load_4bit,
device=args.device
)
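
# spaces.GPU requests a GPU allocation for the duration of each call on Hugging Face ZeroGPU Spaces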
@spaces.GPU
def bot_streaming(message, history):
print(message)
image_path = None
# Check if there's an image in the current message
if message["files"]:
# message["files"][-1] could be a dictionary or a string
if isinstance(message["files"][-1], dict):
image_path = message["files"][-1]["path"]
else:
image_path = message["files"][-1]
else:
# If no image in the current message, look in the history for the last image path
for hist in history:
if isinstance(hist[0], tuple):
image_path = hist[0][0]
# Error handling if no image path is found
if image_path is None:
raise gr.Error("You need to upload an image for LLaVA to work.")
    # image_path is a plain path (or URL) string; chat_llava loads the image itself,
    # so there is no need to open it as a PIL image here
print(f"\033[91m{image_path}, {type(image_path)}\033[0m")
# Generate the prompt for the model
prompt = message['text']
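    # Stream decoded text chunks as they are generated, skipping the prompt and special tokens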
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Set up the generation arguments, including the streamer
generation_kwargs = dict(
args=args,
image_file=image_path,
text=prompt,
tokenizer=tokenizer,
model=llava_model,
streamer=streamer,
image_processor=image_processor, # todo: input model name or path
context_len=context_len)
    # Run `chat_llava` in a background thread so the streamer can be
    # consumed here while generation is still in progress
    thread = Thread(target=chat_llava, kwargs=generation_kwargs)
    thread.start()
# Initialize a buffer to accumulate the generated text
buffer = ""
# Allow the generation to start
time.sleep(0.5)
# Iterate over the streamer to handle the incoming text in chunks
for new_text in streamer:
# Look for the end of text token and remove it
if "<|eot_id|>" in new_text:
new_text = new_text.split("<|eot_id|>")[0]
# Add the new text to the buffer
buffer += new_text
        # The streamer was created with skip_prompt=True, so `buffer` already
        # contains only generated text; no prompt stripping is needed
        # Small pause to smooth the streaming display
        time.sleep(0.06)
        # Yield the accumulated text so Gradio updates the chat window
        yield buffer
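
# Gradio UI: a multimodal chat box (text + image upload) wired to bot_streaming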
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
with gr.Blocks(fill_height=True) as demo:
gr.ChatInterface(
fn=bot_streaming,
title="FinLLaVA",
examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
{"text": "How to make this pastry?", "files": ["./baklava.png"]},
{"text":"What is this?","files":["http://images.cocodataset.org/val2017/000000039769.jpg"]}],
stop_btn="Stop Generation",
multimodal=True,
textbox=chat_input,
chatbot=chatbot,
)
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)