File size: 4,034 Bytes
afff347
ea37c27
afff347
ea37c27
ca317b2
d5bf1ae
29af230
d5bf1ae
 
29af230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140504c
29af230
 
 
 
 
25f126e
 
29af230
 
 
 
 
 
 
 
ee668ff
d5bf1ae
 
 
29af230
ea37c27
d5bf1ae
29af230
ea37c27
29af230
ea37c27
 
d5bf1ae
29af230
 
 
afff347
29af230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec0b15
29af230
afff347
ea37c27
29af230
5b853cd
ea37c27
d5bf1ae
5b853cd
d5bf1ae
ea37c27
29af230
d5bf1ae
29af230
d5bf1ae
29af230
 
 
 
f344ce6
 
6a2e015
 
29af230
 
 
 
 
 
ee668ff
 
d5bf1ae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import time
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer

import spaces
import argparse

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava_llama3.conversation import conv_templates, SeparatorStyle
from llava_llama3.utils import disable_torch_init
from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava_llama3.serve.cli import chat_llava

import requests
from io import BytesIO
import base64
import os
import glob
import pandas as pd
from tqdm import tqdm
import json

# Directory containing this script. It is echoed to the console (wrapped in
# ANSI colour codes) and exported as GRADIO_TEMP_DIR.
root_path = os.path.split(os.path.abspath(__file__))[0]
print(f'\033[92m{root_path}\033[0m')
os.environ['GRADIO_TEMP_DIR'] = root_path

def _make_arg_parser():
    """Build the command-line parser for the demo's model and generation options."""
    cli = argparse.ArgumentParser()
    # (flag, type, default) for every option that takes a value.
    valued_options = (
        ("--model-path", str, "TheFinAI/FinLLaVA"),
        ("--device", str, "cuda"),
        ("--conv-mode", str, "llama_3"),
        ("--temperature", float, 0.01),
        ("--max-new-tokens", int, 512),
    )
    for flag, value_type, fallback in valued_options:
        cli.add_argument(flag, type=value_type, default=fallback)
    # Boolean switches: absent -> False, present -> True.
    for switch in ("--load-8bit", "--load-4bit"):
        cli.add_argument(switch, action="store_true")
    return cli


parser = _make_arg_parser()
args = parser.parse_args()

# The parsed --load-8bit value is overridden here: load_8bit is always True
# by the time the model is loaded, whatever the command line said.
args.load_8bit = True

# Load model
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path, None, 'llava_llama3', args.load_8bit, args.load_4bit, device=args.device,
)

@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's answer about the most recent image in the chat.

    Args:
        message: The incoming chat message. It is read as a mapping with a
            ``"text"`` entry and a ``"files"`` list whose items are either
            path strings or dicts carrying a ``"path"`` key.
        history: Earlier turns. Each turn is indexed at position 0; when that
            element is a tuple, its first item is taken as an image path.

    Yields:
        The text generated so far, growing with every chunk the streamer
        produces.

    Raises:
        gr.Error: If neither the message nor the history provides an image.
    """
    print(message)

    image_file = None
    if message["files"]:
        # Only the last attached file is used.
        latest = message["files"][-1]
        image_file = latest["path"] if isinstance(latest, dict) else latest
    else:
        # No new upload: keep overwriting so the last image found in the
        # history wins.
        for turn in history:
            if isinstance(turn[0], tuple):
                image_file = turn[0][0]

    if image_file is None:
        # gr.Error is an exception class; constructing it without raising
        # would silently end the handler with no message shown to the user.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        """Run the chat call in the worker thread, feeding ``streamer``."""
        print('\033[92mRunning chat\033[0m')
        try:
            return chat_llava(
                args=args,
                image_file=image_file,
                text=message['text'],
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
                streamer=streamer)
        except Exception:
            # Push the end-of-stream signal so the consumer loop below does
            # not block forever waiting on a generation that has died.
            streamer.end()
            raise

    # Generation runs in the background so chunks can be yielded as they arrive.
    thread = Thread(target=generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Short pause between updates (kept from the original pacing).
        time.sleep(0.06)
        yield buffer

# Example prompts offered in the UI; each pairs a question with one image
# (a URL or a file name).
_demo_examples = [
    {
        "text": "What is in this picture?",
        "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"],
    },
    {
        "text": "What is the spending on Healthcare in July? A. 450 B. 600 C. 520 D. 510",
        "files": ["image_107.png"],
    },
    {
        "text": "If 2012 net periodic opeb cost increased at the same pace as the pension cost, what would the estimated 2013 cost be in millions? A. 14.83333 B. 12.5 C. 15.5 D. 13.5",
        "files": ["image_659.png"],
    },
]

# Chat display and multimodal input are created up front and handed to the
# ChatInterface below.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA Demo",
        examples=_demo_examples,
        description="",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

# Enable the request queue with api_open=False, then start the server with
# show_api=False and share=False.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)