File size: 4,852 Bytes
0f14892
 
86cffc0
0f14892
86cffc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa81ccd
 
 
 
 
 
 
86cffc0
359e303
43d70a9
 
86cffc0
 
8130fc5
86cffc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f14892
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
from huggingface_hub import InferenceClient
import spaces

import os
import warnings
import shutil
import time
from threading import Thread

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoProcessor
from transformers import TextIteratorStreamer
import torch
from dc.model import *
from dc.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from dc.conversation import conv_templates, SeparatorStyle
from PIL import Image


# HTML rendered inside the chat window before the first message is sent
# (passed to gr.Chatbot(placeholder=...) when the UI is built).
# NOTE(review): the thumbnail URL points at Mistral-7B-Instruct's social card,
# not at the DenseConnector model this app serves — confirm this is intended.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://cdn-thumbnails.huggingface.co/social-thumbnails/models/mistralai/Mistral-7B-Instruct-v0.3.png" style="width: 70%; max-width: 550px; height: auto; opacity: 0.55;  "> 
   <p style="font-size: 20px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""

# Tokenizer and multimodal model both come from the same DenseConnector
# checkpoint; weights are loaded directly in fp16.
tokenizer = AutoTokenizer.from_pretrained(
    'HuanjinYao/DenseConnector-v1.5-8B',
    use_fast=False,
)
model = LlavaLlamaForCausalLM.from_pretrained(
    'HuanjinYao/DenseConnector-v1.5-8B',
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# The vision encoder may be created lazily by the model class; force its
# weights to load, then move it to the GPU in fp16 and grab its image
# preprocessor for use at request time.
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device='cuda', dtype=torch.float16)
image_processor = vision_tower.image_processor

model.to('cuda')

# Point unk/pad at a reserved special token of the (Llama-3 style) vocabulary.
tokenizer.unk_token = "<|reserved_special_token_0|>"
tokenizer.pad_token = tokenizer.unk_token

# Stop generation on either the tokenizer's EOS id or the "<|eot_id|>"
# end-of-turn token. This list is handed to generate(eos_token_id=...) rather
# than overriding model.generation_config.
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]


def _find_image_path(message, history):
    """Return the image to condition on for this turn, or None if there is none.

    An image attached to the current message takes precedence; otherwise the
    most recent image uploaded earlier in the conversation is reused.

    Args:
        message: Gradio multimodal message dict. Entries of ``message["files"]``
            are either a path string or a dict with a ``"path"`` key.
        history: Gradio chat history as ``(user, assistant)`` pairs; turns in
            which the user uploaded a file carry a tuple ``(path, ...)`` as the
            user entry.
    """
    files = message.get("files") or []
    if files:
        latest = files[-1]
        return latest["path"] if isinstance(latest, dict) else latest
    image_path = None
    for turn in history:
        if isinstance(turn[0], tuple):
            image_path = turn[0][0]  # keep scanning: the latest upload wins
    return image_path


def _build_prompt(text, history, image_token):
    """Render the conversation plus the new user text as one llama_3 prompt.

    The image placeholder token is prepended exactly once: to the first text
    turn in ``history``, or to ``text`` itself when history has no text turns.

    Args:
        text: The user's text for the current turn.
        history: Gradio chat history as ``(user, assistant)`` pairs.
        image_token: Placeholder string marking where image features go.

    Returns:
        The prompt string produced by the conversation template.
    """
    conv = conv_templates['llama_3'].copy()
    token_pending = True
    for user_msg, assistant_msg in history:
        if isinstance(user_msg, tuple):
            # File-upload turn: it carries an image path, not text.
            continue
        if token_pending:
            user_msg = image_token + '\n' + user_msg
            token_pending = False
        conv.append_message(conv.roles[0], user_msg)
        conv.append_message(conv.roles[1], assistant_msg)
    if token_pending:
        text = image_token + '\n' + text
    # Current user turn, then an empty assistant slot for the model to complete.
    conv.append_message(conv.roles[0], text)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Gradio multimodal message dict with ``"text"`` and ``"files"``.
        history: Gradio chat history as ``(user, assistant)`` pairs.

    Yields:
        The reply accumulated so far; each yield replaces the previous one in
        the chat window.

    Raises:
        gr.Error: if neither this message nor any earlier turn has an image.
    """
    print(message)

    # These names are used below but are not visibly imported at file level.
    # NOTE(review): assumes the `dc` package mirrors LLaVA's layout
    # (constants.py / mm_utils.py) — confirm.
    from dc.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
    from dc.mm_utils import tokenizer_image_token

    image_path = _find_image_path(message, history)
    if image_path is None:
        # gr.Error is only shown in the UI (and aborts the call) when raised.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    prompt = _build_prompt(message['text'], history, DEFAULT_IMAGE_TOKEN)

    image = Image.open(image_path).convert('RGB')
    # preprocess() returns a batched (1, C, H, W) pixel tensor.
    # NOTE(review): no model-config aspect-ratio handling (e.g. padding) is
    # applied here — confirm this matches the checkpoint's training setup.
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0)

    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=False, skip_prompt=True)
    generation_kwargs = dict(
        inputs=input_ids.to(model.device),
        images=image_tensor.to(device=model.device, dtype=torch.float16),
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False,
        eos_token_id=terminators,
    )
    # generate() blocks, so run it in a worker thread and consume the streamer here.
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    buffer = ""
    for new_text in streamer:
        # skip_special_tokens=False keeps "<|eot_id|>" in the text; cut at it.
        buffer += new_text.split("<|eot_id|>")[0]
        yield buffer


# Chat display and multimodal (text + image) input box, handed to
# ChatInterface so they replace its default components.
chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        # Title/description name the checkpoint actually loaded by this app
        # (HuanjinYao/DenseConnector-v1.5-8B, a Llama-3-8B based model).
        title="Dense Connector Llama-3-8B",
        examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
                  {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
        description="Try [DenseConnector-v1.5-8B](https://huggingface.co/HuanjinYao/DenseConnector-v1.5-8B). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )


if __name__ == "__main__":
    demo.launch()