# FinLLaVA / app.py
# Last commit 50504f5 "Update app.py" by TobyYang7 (verified) - 1.39 kB
import gradio as gr
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
from PIL import Image
import torch
# Checkpoint to load and the device it is placed on (both forwarded to
# load_pretrained_model below).
model_path, device = "TheFinAI/FinLLaVA", "cuda"

# Conversation template name.
# NOTE(review): not referenced anywhere else in this file.
conv_mode = "llama_3"

# Decoding settings.
# NOTE(review): defined but never passed to chat_llava in this file, so they
# presumably have no effect unless chat_llava reads them itself - confirm.
temperature, max_new_tokens = 0, 512

# Quantization switches forwarded positionally to load_pretrained_model;
# both are disabled.
load_8bit = load_4bit = False
# Load the checkpoint once at import time; predict() reuses these four
# module-level objects on every chat request.
# Positional arguments, in order: checkpoint path, None (presumably the
# "base model" slot - confirm against llava_llama3.model.builder), the model
# name string, then the 8-bit and 4-bit quantization switches.
(
    tokenizer,
    llava_model,
    image_processor,
    context_len,
) = load_pretrained_model(
    model_path, None, "llava_llama3", load_8bit, load_4bit, device=device
)
def _file_path(item):
    """Return the filesystem path of a Gradio file entry (plain path or dict)."""
    # Depending on the Gradio version, uploaded files arrive either as plain
    # path strings or as FileData-style dicts carrying a "path" key.
    if isinstance(item, dict):
        return item.get("path")
    return item


def _latest_history_image(history):
    """Return the most recent file the user uploaded earlier in the chat, or None."""
    for turn in reversed(history or []):
        # Tuple-style history: an upload is stored as ((path, ...), reply).
        if isinstance(turn, (list, tuple)) and turn:
            first = turn[0]
            if isinstance(first, (list, tuple)) and first:
                return _file_path(first[0])
        # Messages-style history: {"role": "user", "content": <file>}.
        # NOTE(review): content shapes assumed from Gradio docs - confirm
        # against the pinned Gradio version.
        elif isinstance(turn, dict) and turn.get("role") == "user":
            content = turn.get("content")
            if isinstance(content, (list, tuple)) and content:
                return _file_path(content[0])
            if isinstance(content, dict) and content.get("path"):
                return content["path"]
    return None


def predict(image, text):
    """Answer one chat turn with FinLLaVA.

    ``gr.ChatInterface(multimodal=True)`` calls ``fn(message, history)``, where
    ``message`` is a dict with ``"text"`` and ``"files"`` keys. The original
    parameter names are kept for backward compatibility, so when called by the
    chat UI ``image`` receives that message dict and ``text`` receives the chat
    history. Direct calls of the form ``predict(image_path, prompt)`` still
    pass straight through to ``chat_llava``.

    Args:
        image: The chat UI's multimodal message dict, or an image file
            reference to hand to ``chat_llava`` unchanged.
        text: The chat history (when ``image`` is a message dict), otherwise
            the prompt string.

    Returns:
        Whatever ``chat_llava`` returns; the chat UI displays it as the reply.
    """
    if isinstance(image, dict):
        message, history = image, text
        text = message.get("text") or ""
        files = message.get("files") or []
        # Prefer an image attached to this turn; otherwise reuse the latest
        # one from earlier in the conversation so follow-up questions still
        # refer to it. May be None if no image was ever uploaded.
        # NOTE(review): behaviour of chat_llava with image_file=None is not
        # visible here - confirm it is handled.
        image = _file_path(files[-1]) if files else _latest_history_image(history)

    output = chat_llava(
        args=None,
        image_file=image,
        text=text,
        tokenizer=tokenizer,
        model=llava_model,
        image_processor=image_processor,
        context_len=context_len,
    )
    return output
# The chat widgets are built up front and handed to ChatInterface, which
# renders them inside the Blocks layout below.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=predict,
        title="FinLLaVA",
        # NOTE(review): these example images must sit next to app.py, and the
        # prompts look like generic LLaVA demo examples rather than financial
        # ones - consider replacing them.
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
        ],
        # NOTE(review): Gradio's stop button can only interrupt streaming
        # (generator) functions; predict returns a single value.
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

# Enable the request queue with its API endpoints closed; `demo` stays fully
# configured even when this module is imported rather than run.
demo.queue(api_open=False)

# Only start the web server when executed as a script (e.g. `python app.py`),
# so importing the module does not block on a running server.
if __name__ == "__main__":
    demo.launch(show_api=False, share=False)