File size: 5,669 Bytes
0f14892
 
86cffc0
0f14892
86cffc0
 
 
 
 
 
 
 
 
 
910652f
86cffc0
 
a637de1
86cffc0
 
5886f69
 
 
 
 
 
291ddd0
fa81ccd
86cffc0
359e303
43d70a9
 
86cffc0
 
8130fc5
86cffc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf5acef
fdacea4
86cffc0
 
 
1d64fb8
86cffc0
b99ba25
86cffc0
 
 
 
 
9986a64
86cffc0
 
dd694ec
baaa59a
86cffc0
8abdcd2
a637de1
86cffc0
f34e2ac
 
ec4c4f7
86cffc0
95be3d0
f34e2ac
f966a64
f34e2ac
dd694ec
86cffc0
471412e
 
5774921
86cffc0
dd694ec
86cffc0
5886f69
86cffc0
 
 
 
 
 
baaa59a
86cffc0
 
 
 
 
 
 
 
 
b7f9d75
86cffc0
 
 
 
8abdcd2
5886f69
 
86cffc0
 
 
 
5886f69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86cffc0
0f14892
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import gradio as gr
from huggingface_hub import InferenceClient
import spaces

import os
import warnings
import shutil
import time
from threading import Thread

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoProcessor
from transformers import TextIteratorStreamer
import torch
from dc.model import *
from dc.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from dc.conversation import conv_templates, SeparatorStyle
from PIL import Image
from dc.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path


# HTML shown inside the empty chatbot before the first message is sent.
# Built line by line so the exact markup (including its leading/trailing
# newlines) is explicit.
PLACEHOLDER = "\n".join([
    "",
    '<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">',
    '    <p style="font-size: 20px; margin-bottom: 2px; opacity: 0.65;">Upload an image to start the conversation.</p>',
    '   <p style="font-size: 20px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>',
    "</div>",
    "",
])


# Hub repository providing both the tokenizer and the model weights.
_MODEL_ID = 'HuanjinYao/DenseConnector-v1.5-8B'

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_fast=False)
model = LlavaLlamaForCausalLM.from_pretrained(
    _MODEL_ID,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# The vision encoder may not have its weights loaded yet; load them if needed,
# then place it on the GPU in half precision alongside the language model.
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device='cuda', dtype=torch.float16)
image_processor = vision_tower.image_processor

model.to('cuda')

# Use a reserved special token as both the unk and pad token.
# NOTE(review): presumably because the base tokenizer lacks one — confirm.
tokenizer.unk_token = "<|reserved_special_token_0|>"
tokenizer.pad_token = tokenizer.unk_token

# Generation stops at either the tokenizer's EOS id or the chat template's
# end-of-turn token "<|eot_id|>".
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]


def _extract_image_path(message, history):
    """Return the path of the image the conversation is about, or ``None``.

    An image attached to the current message takes precedence; otherwise the
    most recent image-upload turn in ``history`` is reused.

    Args:
        message: multimodal textbox payload. Items of ``message["files"]`` are
            either dicts carrying a ``"path"`` key or plain path strings.
        history: prior ``(user, assistant)`` turns; image uploads appear as
            turns whose user element is a tuple with the file path first.
    """
    files = message.get("files") or []
    if files:
        latest = files[-1]
        return latest["path"] if isinstance(latest, dict) else latest

    path = None
    for turn in history:
        if isinstance(turn[0], tuple):
            path = turn[0][0]  # keep scanning so the *last* upload wins
    return path


def _build_prompt(text, history):
    """Render the ``llama_3`` conversation prompt for ``history`` plus ``text``.

    File-upload turns (tuple user entries) are skipped because the image is
    passed to the model separately as pixel values. ``DEFAULT_IMAGE_TOKEN`` is
    prepended to the first textual user turn, so the prompt contains exactly
    one image placeholder no matter how the history is laid out.
    """
    conv = conv_templates['llama_3'].copy()
    needs_image_token = True
    for user, assistant in history:
        if isinstance(user, tuple):
            continue
        if needs_image_token:
            user = DEFAULT_IMAGE_TOKEN + '\n' + user
            needs_image_token = False
        conv.append_message(conv.roles[0], user)
        conv.append_message(conv.roles[1], assistant)

    if needs_image_token:
        text = DEFAULT_IMAGE_TOKEN + '\n' + text
    conv.append_message(conv.roles[0], text)
    conv.append_message(conv.roles[1], None)  # open slot for the model's reply
    return conv.get_prompt()


@spaces.GPU
def bot_streaming(message, history, temperature=0.0, max_new_tokens=1024):
    """Stream the model's reply to ``message`` about the conversation's image.

    Args:
        message: multimodal textbox payload with ``"text"`` and ``"files"`` keys.
        history: prior ``(user, assistant)`` turns supplied by ``gr.ChatInterface``.
        temperature: sampling temperature (the "Temperature" slider). ``0`` or
            ``None`` selects greedy decoding, which is the default for direct calls.
        max_new_tokens: generation budget (the "Max new tokens" slider).

    Yields:
        The accumulated reply text after each streamed chunk.

    Raises:
        gr.Error: if neither the message nor the history contains an image.
    """
    image_path = _extract_image_path(message, history)
    if image_path is None:
        # gr.Error only reaches the user when it is raised, not merely constructed.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    prompt = _build_prompt(message['text'], history)

    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = process_images([image], image_processor, model.config)[0]
    image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        eos_token_id=terminators,
    )
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=float(temperature))
    else:
        generation_kwargs["do_sample"] = False

    # generate() blocks until finished, so run it on a worker thread and consume
    # decoded text chunks from the streamer as they are produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Defensive: drop anything from an end-of-turn marker onward in case it
        # survives skip_special_tokens.
        buffer += new_text.split("<|eot_id|>")[0]
        yield buffer
    thread.join()


# Chat display and input box; ChatInterface renders them inside the layout below.
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label="Chat with Dense Connector")
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    # Collapsed accordion holding the generation controls. render=False lets
    # ChatInterface place these components itself.
    _parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    _temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.95,
        label="Temperature",
        render=False,
    )
    _max_new_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    )
    gr.ChatInterface(
        fn=bot_streaming,
        title="DenseConnector-v1.5-8B",
        examples=[{"text": "Describe this movie.", "files": ["./Interstellar.jpg"]}],
        description="Try [DenseConnector-v1.5-8B](https://huggingface.co/HuanjinYao/DenseConnector-v1.5-8B). Upload an image and start chatting about it. If you don't upload an image, you will receive an error.",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        additional_inputs_accordion=_parameters_accordion,
        # Per Gradio's ChatInterface docs, these values are passed to `fn`
        # after (message, history), in this order.
        additional_inputs=[_temperature_slider, _max_new_tokens_slider],
    )


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()