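"""Gradio demo for toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct: a
streaming multimodal chatbot that answers questions about receipt images."""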
import time
from threading import Thread
from typing import Dict, List

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    MllamaForConditionalGeneration,
    TextIteratorStreamer,
)
# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"

# Model initialization
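# bfloat16 halves the memory footprint versus fp32 and matches the dtype
# the Llama 3.2 Vision checkpoints ship in; AutoProcessor bundles the
# tokenizer and image preprocessor used by the chat template below.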
model = MllamaForConditionalGeneration.from_pretrained(
    CHECKPOINT, torch_dtype=torch.bfloat16
).to(DEVICE)
processor = AutoProcessor.from_pretrained(CHECKPOINT)
def process_chat_history(history: List) -> tuple[List[Dict], List[Image.Image]]:
    """
    Process chat history to extract messages and images.

    Args:
        history: List of chat messages in Gradio's tuple format

    Returns:
        Tuple containing processed messages and images
    """
    messages = []
    images = []
    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            # Image turn: msg[0] holds the file path; the paired question
            # and answer live in the next history entry.
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": history[i + 1][0]},
                            {"type": "image"},
                        ],
                    },
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": history[i + 1][1]}],
                    },
                ]
            )
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif i > 0 and isinstance(history[i - 1][0], tuple) and isinstance(msg[0], str):
            # This text turn was already folded into the preceding image turn.
            continue
        elif isinstance(msg[0], str):
            # Text-only turn: forward the user/assistant pair as-is.
            messages.extend(
                [
                    {"role": "user", "content": [{"type": "text", "text": msg[0]}]},
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": msg[1]}],
                    },
                ]
            )

    return messages, images
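# Illustrative tuple-format history (paths and text are placeholders):
#   [[("receipt.jpg",), None], ["What is the total?", "<model reply>"]]
# becomes one user message carrying the text plus an image placeholder and
# the assistant reply, with the PIL image collected separately.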
@spaces.GPU  # request a GPU slot per call on ZeroGPU Spaces (no-op elsewhere)
def bot_streaming(message: Dict, history: List, max_new_tokens: int = 250):
    """
    Generate streaming responses for the chatbot.

    Args:
        message: Current message containing text and files
        history: Chat history
        max_new_tokens: Maximum number of tokens to generate

    Yields:
        Generated text buffer
    """
    text = message["text"]
    messages, images = process_chat_history(history)
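    # `message` carries the current text and any attached files from the
    # MultimodalTextbox; `history` holds only the previous turns.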
    # Handle the current message; files may arrive as plain paths or as
    # dicts with a "path" key depending on the Gradio version.
    if len(message["files"]) == 1:
        image = (
            Image.open(message["files"][0])
            if isinstance(message["files"][0], str)
            else Image.open(message["files"][0]["path"])
        ).convert("RGB")
        images.append(image)
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": text}, {"type": "image"}],
            }
        )
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
    # Process inputs
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = (
        processor(text=texts, images=images, return_tensors="pt")
        if images
        else processor(text=texts, return_tensors="pt")
    ).to(DEVICE)
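    # `inputs` now holds input_ids and attention_mask, plus the image
    # tensors for the vision tower whenever images were supplied.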
    # Setup streaming
    streamer = TextIteratorStreamer(
        processor, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
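    # generate() blocks, so it runs on the worker thread; iterating the
    # streamer here surfaces decoded tokens as soon as they are produced.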
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)  # brief pause to smooth UI updates while streaming
        yield buffer
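# Chat UI: a multimodal textbox accepts text plus an image, and a slider
# exposes the token budget as an additional input to bot_streaming.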
demo = gr.ChatInterface(
    fn=bot_streaming,
    textbox=gr.MultimodalTextbox(placeholder="Ask me anything..."),
    additional_inputs=[
        gr.Slider(
            minimum=10,
            maximum=500,
            value=250,
            step=10,
            label="Maximum number of new tokens to generate",
        )
    ],
    examples=[
        [
            {
                "text": "What is the total amount in this bill?",
                "files": ["./examples/01.jpg"],
            },
            200,
        ],
        [
            {
                "text": "What is the name of the restaurant in this bill?",
                "files": ["./examples/02.jpg"],
            },
            200,
        ],
    ],
    cache_examples=False,
    stop_btn="Stop",
    fill_height=True,
    multimodal=True,
    type="tuples",  # process_chat_history expects Gradio's tuple-format history
)
if __name__ == "__main__":
    demo.launch(debug=True)