Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,034 Bytes
afff347 ea37c27 afff347 ea37c27 ca317b2 d5bf1ae 29af230 d5bf1ae 29af230 140504c 29af230 25f126e 29af230 ee668ff d5bf1ae 29af230 ea37c27 d5bf1ae 29af230 ea37c27 29af230 ea37c27 d5bf1ae 29af230 afff347 29af230 cec0b15 29af230 afff347 ea37c27 29af230 5b853cd ea37c27 d5bf1ae 5b853cd d5bf1ae ea37c27 29af230 d5bf1ae 29af230 d5bf1ae 29af230 f344ce6 6a2e015 29af230 ee668ff d5bf1ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import time
from threading import Thread
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer
import spaces
import argparse
from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava_llama3.conversation import conv_templates, SeparatorStyle
from llava_llama3.utils import disable_torch_init
from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava_llama3.serve.cli import chat_llava
import requests
from io import BytesIO
import base64
import os
import glob
import pandas as pd
from tqdm import tqdm
import json
# Resolve the directory holding this script, echo it in green, and point
# Gradio's temporary-file directory at it. This runs before any gr.Blocks or
# component below is created.
_this_file = os.path.abspath(__file__)
root_path = os.path.dirname(_this_file)
print('\033[92m' + root_path + '\033[0m')
os.environ.update(GRADIO_TEMP_DIR=root_path)
# Command-line options. The parsed `args` namespace is also passed to chat_llava
# on every request (presumably it reads conv_mode / temperature / max_new_tokens
# from it -- confirm in llava_llama3.serve.cli).
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.01)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Default to 8-bit loading, but only when 4-bit was not explicitly requested.
# Forcing load_8bit unconditionally made --load-8bit meaningless and left both
# quantization flags set whenever --load-4bit was passed.
if not args.load_4bit:
    args.load_8bit = True

# Load tokenizer, model, image processor and context length once at import
# time. The 2nd/3rd positional arguments are presumably model_base (None) and
# model_name ('llava_llama3') -- confirm against llava_llama3.model.builder.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device)
def _resolve_image_file(message, history):
    """Return the image path this chat turn should be answered about, or None.

    A file attached to the current message takes precedence. Otherwise the most
    recent image uploaded earlier in the conversation is used.

    Args:
        message: MultimodalTextbox payload. Its ``"files"`` entries are either
            dicts carrying a ``"path"`` key or plain paths.
        history: previous turns as ``[user, bot]`` pairs. An earlier upload
            appears as a tuple whose first element is the file path (tuple-style
            history -- NOTE(review): confirm for the Gradio version in use).
    """
    files = message.get("files")
    if files:
        latest = files[-1]
        return latest["path"] if isinstance(latest, dict) else latest
    # Walk backwards so the newest uploaded image wins.
    for turn in reversed(history):
        if isinstance(turn[0], tuple):
            return turn[0][0]
    return None


@spaces.GPU
def bot_streaming(message, history):
    """Stream FinLLaVA's answer for one Gradio ChatInterface turn.

    Args:
        message: MultimodalTextbox payload with ``"text"`` and ``"files"`` keys.
        history: earlier conversation turns, searched for an image when the
            current message has none.

    Yields:
        The response text accumulated so far; it grows as tokens arrive.

    Raises:
        gr.Error: if no image is available, or if generation fails.
    """
    print(message)
    image_file = _resolve_image_file(message, history)
    if image_file is None:
        # gr.Error only reaches the user when it is raised, not merely built.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    failures = []

    def generate():
        # Worker thread: chat_llava receives `streamer` and is expected to push
        # decoded text into it while the loop below drains it.
        print('\033[92mRunning chat\033[0m')
        try:
            chat_llava(
                args=args,
                image_file=image_file,
                text=message['text'],
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
                streamer=streamer)
        except Exception as exc:
            failures.append(exc)
            # Without an end-of-stream signal the consumer loop would block
            # forever once this thread has died.
            streamer.end()

    worker = Thread(target=generate)
    worker.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.06)  # presumably paces UI updates for a typing effect
        yield buffer

    worker.join()
    if failures:
        raise gr.Error(f"Generation failed: {failures[0]}") from failures[0]
# Chat widgets are created outside the Blocks context and handed to
# ChatInterface, which places them in the layout.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

# (prompt text, image file) pairs offered as clickable examples.
_EXAMPLE_PAIRS = [
    ("What is in this picture?",
     "http://images.cocodataset.org/val2017/000000039769.jpg"),
    ("What is the spending on Healthcare in July? A. 450 B. 600 C. 520 D. 510",
     "image_107.png"),
    ("If 2012 net periodic opeb cost increased at the same pace as the pension cost, what would the estimated 2013 cost be in millions? A. 14.83333 B. 12.5 C. 15.5 D. 13.5",
     "image_659.png"),
]

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA Demo",
        description="",
        examples=[{"text": prompt, "files": [image]} for prompt, image in _EXAMPLE_PAIRS],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

# Enable the request queue with its API closed, then serve the app without the
# API page and without a public share link.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)