Spaces:
Running
on
Zero
Running
on
Zero
import time | |
from threading import Thread | |
import gradio as gr | |
import torch | |
from PIL import Image | |
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer | |
import spaces | |
import argparse | |
from llava_llama3.model.builder import load_pretrained_model | |
from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
from llava_llama3.conversation import conv_templates, SeparatorStyle | |
from llava_llama3.utils import disable_torch_init | |
from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path | |
from llava_llama3.serve.cli import chat_llava | |
import requests | |
from io import BytesIO | |
import base64 | |
import os | |
import glob | |
import pandas as pd | |
from tqdm import tqdm | |
import json | |
# Keep Gradio's temporary/upload files in the directory containing this script.
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)
os.environ['GRADIO_TEMP_DIR'] = root_path

# Command-line flags as (flag name, keyword arguments for add_argument),
# registered in this exact order.
_CLI_FLAGS = (
    ("--model-path", {"type": str, "default": "TheFinAI/FinLLaVA"}),
    ("--device", {"type": str, "default": "cuda"}),
    ("--conv-mode", {"type": str, "default": "llama_3"}),
    ("--temperature", {"type": float, "default": 0.01}),
    ("--max-new-tokens", {"type": int, "default": 512}),
    ("--load-8bit", {"action": "store_true"}),
    ("--load-4bit", {"action": "store_true"}),
)

parser = argparse.ArgumentParser()
for _flag_name, _flag_kwargs in _CLI_FLAGS:
    parser.add_argument(_flag_name, **_flag_kwargs)
args = parser.parse_args()

# Load the tokenizer, model and image processor once at import time so that
# every request handled by the app shares the same instances.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,  # NOTE(review): presumably the optional base-model path — confirm in llava_llama3.model.builder
    'llava_llama3',  # model name handed to the project's builder
    args.load_8bit,
    args.load_4bit,
    device=args.device,
)
def _latest_image_file(message, history):
    """Return the image path to use for this turn, or None if none is available.

    The last file attached to *message* wins; a file entry may be a plain path
    or a dict carrying the path under ``"path"``. Without an attachment, the
    most recent history turn whose first element is a tuple supplies the path
    (the first item of that tuple).
    """
    files = message["files"]
    if files:
        latest = files[-1]
        return latest["path"] if isinstance(latest, dict) else latest

    image_file = None
    for turn in history:
        if isinstance(turn[0], tuple):
            # Keep scanning: a later upload in the conversation overrides an earlier one.
            image_file = turn[0][0]
    return image_file


# Required on ZeroGPU hardware so the handler is given a GPU; the `spaces`
# package documents this decorator as a no-op outside ZeroGPU.
@spaces.GPU
def bot_streaming(message, history):
    """Stream FinLLaVA's answer to a multimodal chat message.

    Args:
        message: Gradio multimodal message dict with a ``"text"`` entry and a
            ``"files"`` list (entries are path strings or dicts with ``"path"``).
        history: Prior chat turns from ``gr.ChatInterface``; a turn whose first
            element is a tuple is treated as an uploaded file.

    Yields:
        The cumulative generated text after each chunk read from the streamer.

    Raises:
        gr.Error: If no image is attached to this message or any earlier turn,
            or if ``chat_llava`` raises in the worker thread.
    """
    print("triggered")
    print(message)

    image_file = _latest_image_file(message, history)
    if image_file is None:
        # gr.Error is only shown in the UI when raised; merely constructing it
        # ended the stream silently with no feedback to the user.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    worker_errors = []

    def generate():
        # Assumes chat_llava feeds generated tokens into `streamer` and ends it
        # when generation finishes — TODO confirm in llava_llama3.serve.cli.
        print('Running chat')
        try:
            return chat_llava(
                args=args,
                image_file=image_file,
                text=message['text'],
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
                streamer=streamer)
        except Exception as exc:  # re-raised to the user after the stream ends
            worker_errors.append(exc)
            # Without a stop signal, iterating the streamer below blocks forever.
            streamer.end()
            return None

    thread = Thread(target=generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.06)  # paces the updates yielded to the UI
        print(buffer)
        yield buffer
    # Let the worker finish so any late failure is recorded before we check it.
    thread.join()

    if worker_errors:
        raise gr.Error(f"Generation failed: {worker_errors[0]}") from worker_errors[0]
# Example prompts offered under the chat box; the two local image paths are
# relative, so they are resolved against the working directory at runtime.
_DEMO_EXAMPLES = [
    {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
    {"text": "What is the spending on Healthcare in July? A. 450 B. 600 C. 520 D. 510", "files": ["image_107.png"]},
    {"text": "If 2012 net periodic opeb cost increased at the same pace as the pension cost, what would the estimated 2013 cost be in millions? A. 14.83333 B. 12.5 C. 15.5 D. 13.5", "files": ["image_659.png"]},
]

_DEMO_DESCRIPTION = "This is a demo of FinLLaVA-8B. For more details, see our paper: https://huggingface.co/papers/2408.11878"

# These components are created outside the Blocks context and handed to
# ChatInterface, which places them in its own layout.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        multimodal=True,
        chatbot=chatbot,
        textbox=chat_input,
        title="FinLLaVA-8B Demo",
        description=_DEMO_DESCRIPTION,
        examples=_DEMO_EXAMPLES,
        stop_btn="Stop Generation",
    )

# Enable the request queue (with its API exposed), then start the server.
demo.queue(api_open=True)
demo.launch(show_api=True, share=False)