Spaces:
Runtime error
Runtime error
File size: 6,138 Bytes
5c6427d 2f4b8e0 5c6427d 2f4b8e0 5c6427d 64f1def b88b332 64f1def 9f9ca81 64f1def 2f4b8e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import re
import copy
import secrets
from pathlib import Path
# Constants
# NOTE(review): BOX_TAG_PATTERN is not referenced anywhere in this file; kept for
# compatibility with code that may import it.
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
# Characters treated as punctuation when trimming a trailing mark from user input.
PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
# Single source of truth for the Hugging Face model id used by tokenizer and model.
MODEL_ID = "Qwen/Qwen-VL-Chat-Int4"
# Initialize model and tokenizer. trust_remote_code=True executes the model repo's
# own Python code (Qwen-VL ships its own tokenizer/model classes).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", trust_remote_code=True
).eval()
# Per-character escapes applied to lines inside fenced code blocks so that code
# is displayed literally (no HTML/Markdown interpretation). Using one translate
# table makes the escaping a single pass, so entity text is never re-escaped.
_CODE_ESCAPES = str.maketrans({
    "`": "\\`",
    "&": "&amp;",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def format_text(text):
    """Format text as HTML for rendering in the chat UI.

    Empty lines are dropped and the remaining lines are joined with ``<br>``.
    A line containing a code fence (three backticks) alternately opens
    ``<pre><code class="language-LANG">`` (LANG = text after the last backtick)
    and closes ``</code></pre>``. Lines between an opening and closing fence are
    HTML-escaped so code is shown literally; text outside fences is left as-is.

    Args:
        text: the raw message text.

    Returns:
        The HTML string ("" for empty input).
    """
    import html

    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # Opening fence: the language label comes from the model output,
                # so escape it before putting it into an HTML attribute.
                language = html.escape(line.split("`")[-1])
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                # Inside a code block: escape so the code renders verbatim.
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)
def _split_pending_turn(task_history):
    """Split task_history into answered (query, answer) turns and the pending query text."""
    turns = []
    picture_number = 1
    pictures, texts = [], []
    for query, answer in task_history:
        if isinstance(query, (tuple, list)):
            # Images are referenced inline in the prompt as "Picture N: <img>path</img>",
            # numbered across the whole conversation.
            pictures.append(f"Picture {picture_number}: <img>{query[0]}</img>\n")
            picture_number += 1
        else:
            texts.append(query)
        if answer is not None:
            # An answered entry closes its turn; pictures go before the turn's text.
            turns.append(("".join(pictures) + "".join(texts), answer))
            pictures, texts = [], []
    return turns, "".join(pictures) + "".join(texts)


def get_chat_response(chatbot, task_history):
    """Answer the pending user turn with the model and record the reply.

    ``task_history`` holds ``(query, answer)`` pairs where ``query`` is either
    text or a 1-tuple ``(image_path,)`` and ``answer`` is None until answered.
    Every entry after the last answered one (text and pictures, in any order)
    is sent as a single query; earlier answered turns are passed as history.

    Args:
        chatbot: display history of ``(user, bot)`` pairs; expected to be kept
            in step with ``task_history`` by the input handlers.
        task_history: raw history as described above; updated in place.

    Returns:
        ``(chatbot, task_history)``. The reply is stored on the last entry of
        each (raw in task_history, HTML-formatted in chatbot). Both are
        returned unchanged when there is no pending query.
    """
    history, message = _split_pending_turn(task_history)
    if not message:
        # Nothing pending (empty history or everything already answered).
        return chatbot, task_history

    # Qwen-VL-Chat's remote-code chat() applies the model's chat template and
    # returns only the newly generated reply, avoiding the prompt echo and the
    # prompt-inclusive max_length limit of a raw generate() call.
    response, _ = model.chat(tokenizer, message, history=history)

    # Fill in the pending turn instead of appending a new one, so regeneration
    # (which clears the last answer) overwrites it cleanly.
    task_history[-1] = (task_history[-1][0], response)
    formatted = format_text(response)
    if chatbot:
        chatbot[-1] = (chatbot[-1][0], formatted)
    else:
        chatbot.append((format_text(message), formatted))
    return chatbot, task_history
def handle_text_input(history, task_history, text):
    """Append a user text message to the display and task histories.

    When the text ends in exactly one punctuation mark (the character before
    it is not punctuation), that mark is dropped from the copy recorded for
    the model; the displayed copy keeps the text as typed.

    Returns:
        ``(new display history, new task history, "")``; the trailing empty
        string can be used to reset an input component.
    """
    ends_with_lone_mark = (
        len(text) >= 2
        and text[-1] in PUNCTUATION
        and text[-2] not in PUNCTUATION
    )
    model_text = text[:-1] if ends_with_lone_mark else text
    display_entry = (format_text(text), None)
    task_entry = (model_text, None)
    return history + [display_entry], task_history + [task_entry], ""
def handle_file_upload(history, task_history, file):
    """Append an uploaded file to the display and task histories.

    Args:
        history: display history list.
        task_history: raw history list.
        file: a filesystem path (``str`` or ``os.PathLike``, which is what
            ``gr.Image(type="filepath")`` supplies) or a file-like object
            exposing ``.name``.

    Returns:
        ``(new display history, new task history)``, each extended with a
        ``((path,), None)`` entry.
    """
    import os

    # A str has no .name attribute, so the original file.name access crashed
    # for filepath inputs; Path.name would yield only the basename, so handle
    # path-likes explicitly before falling back to the file-object protocol.
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name
    entry = ((path,), None)
    return history + [entry], task_history + [entry]
def clear_input():
    """Build a Gradio update that sets the bound component's value to an empty string."""
    reset = gr.update(value="")
    return reset
def clear_history(task_history):
    """Empty the task history in place and hand back a fresh, empty display list."""
    # In-place clear so every holder of this list observes the reset.
    task_history.clear()
    empty_display = []
    return empty_display
def handle_regeneration(chatbot, task_history):
    """Discard the last answer and generate it again.

    Args:
        chatbot: display history of ``(user, bot)`` pairs; mutated in place.
        task_history: raw history of ``(query, answer)`` pairs; mutated in place.

    Returns:
        ``(chatbot, task_history)`` on every path (the same shape that
        ``get_chat_response`` returns). Both are returned unchanged when
        there is nothing to regenerate.
    """
    if not task_history or not chatbot:
        return chatbot, task_history
    last_entry = task_history[-1]
    if last_entry[1] is None:
        # The last turn has no answer yet, so there is nothing to regenerate.
        return chatbot, task_history

    # Mark the last turn as pending again.
    task_history[-1] = (last_entry[0], None)

    # Reset the answer shown for the last turn in the display history.
    last_display = chatbot.pop()
    if last_display[0] is None and chatbot:
        # The popped row had no user side; clear the answer on the row before it.
        chatbot[-1] = (chatbot[-1][0], None)
    else:
        chatbot.append((last_display[0], None))

    return get_chat_response(chatbot, task_history)
# Module-level conversation state: display history (chatbot) and raw model
# history (task_history). NOTE(review): being module globals, these are shared
# by every session of the app.
chatbot, task_history = [], []
def main_function(text, image):
    """Run one chat turn for the UI form and return the model's reply.

    Args:
        text: text typed by the user; may be empty.
        image: filepath of an uploaded image, or None.

    Returns:
        The raw model reply as plain text, or "" when there was nothing to send.

    Uses and rebinds the module-level ``chatbot``/``task_history`` lists.
    """
    global chatbot, task_history
    if not text and not image:
        # Nothing to send (e.g. a change event with cleared inputs); never
        # query the model with an empty turn.
        return ""
    if text:
        chatbot, task_history = handle_text_input(chatbot, task_history, text)
    if image:
        chatbot, task_history = handle_file_upload(chatbot, task_history, image)
    chatbot, task_history = get_chat_response(chatbot, task_history)
    # The output component is plain text, so return the raw reply kept in
    # task_history rather than chatbot's HTML-formatted copy.
    return task_history[-1][1] or ""
def clear_history_fn():
    """Empty the shared display and task histories and return a status message."""
    # Clearing in place mutates the existing lists, so no global rebinding is needed.
    for store in (chatbot, task_history):
        store.clear()
    return "History cleared."
# Custom CSS
# Passed to gr.Blocks(css=...) below; caps the width of the app container.
css = '''
.gradio-container {
max-width: 800px !important;
}
'''
# Build the UI from Blocks components. The previous version nested
# gr.Interface(...).launch() inside this context, which launched a second,
# blocking app; it also passed Interface an unsupported `layout` keyword and
# called a non-existent Blocks.add_button.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Qwen-VL-Chat Bot")
    gr.Markdown(
        "## Developed by Keyvan Hardani (Keyvven on [Twitter](https://twitter.com/Keyvven))\n"
        "Special thanks to [@Artificialguybr](https://twitter.com/artificialguybr) for the inspiration from his code.\n"
        "### Qwen-VL: A Multimodal Large Vision Language Model by Alibaba Cloud\n"
    )
    with gr.Row():
        text_input = gr.Textbox(lines=2, label="Input")
        # type="filepath" hands main_function a path string.
        image_input = gr.Image(type="filepath", label="Upload Image")
    response_output = gr.Textbox(label="Response")
    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("🧹 Clear History")
    gr.Markdown("### Key Features:\n- **Strong Performance**: Surpasses existing LVLMs on multiple English benchmarks including Zero-shot Captioning and VQA.\n- **Multi-lingual Support**: Supports English, Chinese, and multi-lingual conversation.\n- **High Resolution**: Utilizes 448*448 resolution for fine-grained recognition and understanding.")

    # Generate only on explicit submission; the former live=True ran the model
    # on every keystroke and appended each partial input to the shared history.
    submit_button.click(
        main_function,
        inputs=[text_input, image_input],
        outputs=response_output,
    )
    clear_button.click(clear_history_fn, inputs=None, outputs=response_output)

# Queue requests so long-running generations are processed in order instead
# of competing for the model.
demo.queue()
demo.launch(share=True)
|