gigant's picture
Update app.py
7050043 verified
raw
history blame
5.33 kB
import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import io
import base64
from datasets import load_dataset
# Upper bound on visual tokens per image; also shown in the "Token Count" label of the UI.
max_token_budget = 512
# Pixel bounds handed to the processor.
# NOTE(review): the 28 * 28 factor presumably corresponds to the pixel area of
# one visual token in Qwen2-VL — confirm against the model documentation.
min_pixels = 1 * 28 * 28
max_pixels = max_token_budget * 28 * 28
# Processor for Qwen2-VL-2B-Instruct, configured with the pixel bounds above.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
# "train" split of the gigant/tib-bench-vlm dataset; documents are looked up by integer index.
ds = load_dataset("gigant/tib-bench-vlm")["train"]
def segments(example):
    """Build the interleaved transcript text for one dataset example.

    One ``<image>`` marker is emitted per keyframe, followed by the text of
    every not-yet-consumed transcript segment whose scaled ``seek`` value is
    strictly below that keyframe's end timestamp. Transcript segments are
    consumed in order and each is emitted at most once; segments that start
    at or after the end of the last keyframe are not emitted.

    Args:
        example: Mapping with ``example["keyframes"]["timestamp"]`` (a
            sequence of ``(start, end)`` pairs) and
            ``example["transcript_segments"]`` holding parallel ``"seek"``
            and ``"text"`` sequences.

    Returns:
        A single string of ``<image>`` markers interleaved with transcript text.
    """
    transcript = example["transcript_segments"]
    seeks = transcript["seek"]
    texts = transcript["text"]

    parts = []
    segment_i = 0
    for timestamp in example["keyframes"]["timestamp"]:
        parts.append("<image>")
        end = timestamp[1]
        # NOTE(review): the 0.01 factor presumably converts `seek` (10 ms
        # units) to seconds so it is comparable with `end` — confirm.
        while segment_i < len(seeks) and end > seeks[segment_i] * 0.01:
            parts.append(texts[segment_i])
            segment_i += 1
    return "".join(parts)
def create_interleaved_html(text, slides, scale=0.4, max_width=600):
    """Create an HTML string with interleaved images and text segments.

    ``text`` is split on the ``<image>`` marker. The chunk before the first
    marker is rendered as text only; every later chunk is preceded by the
    corresponding slide (``slides[0]`` for the first marker, and so on),
    resized, PNG-encoded and embedded directly in the HTML as a base64 data
    URI.

    Args:
        text: Interleaved text containing ``<image>`` markers.
        slides: Indexable collection of image objects exposing ``width``,
            ``height`` and ``resize((w, h))``, whose result has
            ``save(buffer, format=...)``.
        scale: Factor applied to each image's width and height.
        max_width: Maximum rendered width in pixels; wider (scaled) images
            are shrunk to this width, keeping their aspect ratio.

    Returns:
        The concatenated HTML markup.

    Raises:
        IndexError: If ``text`` contains more ``<image>`` markers than there
            are slides.
    """
    # Imported locally so this function does not depend on an edit to the
    # file-level import block.
    from html import escape

    parts = []
    for j, segment in enumerate(text.split("<image>")):
        # The first chunk precedes any marker, so it has no image.
        if j > 0:
            img = slides[j - 1]
            # Clamp to at least one pixel: tiny images could otherwise scale
            # down to a zero-sized dimension, which cannot be resized to.
            img_width = max(1, int(img.width * scale))
            img_height = max(1, int(img.height * scale))
            if img_width > max_width:
                ratio = max_width / img_width
                img_width = max_width
                img_height = max(1, int(img_height * ratio))
            # Convert image to base64
            buffer = io.BytesIO()
            img.resize((img_width, img_height)).save(buffer, format="PNG")
            img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
            parts.append(f'<img src="data:image/png;base64,{img_str}" style="max-width: {max_width}px; display: block; margin: 20px auto;">')
        # Add the text segment after the image. The text is escaped so that
        # characters such as "<" or "&" in a transcript are displayed
        # literally instead of being interpreted as markup.
        parts.append(f'<div style="white-space: pre-wrap;">{escape(segment, quote=False)}</div>')
    return "".join(parts)
def doc_to_messages(text, slides):
    """Convert interleaved text and slides into model-ready processor inputs.

    ``text`` is split on ``<image>``; the resulting chunks and the slides are
    assembled into a single user message whose content alternates text and
    image entries. The message is rendered with the processor's chat
    template (with the generation prompt appended), the rendered prompt is
    printed, and prompt plus vision inputs are run through the module-level
    ``processor``.

    Args:
        text: Interleaved text containing ``<image>`` markers.
        slides: Indexable collection of images, one per ``<image>`` marker.

    Returns:
        The object returned by ``processor(...)`` with ``return_tensors="pt"``.
    """
    pieces = text.split("<image>")

    # The chunk before the first marker carries no slide.
    content = [{"type": "text", "text": pieces[0]}]
    for slide_idx, piece in enumerate(pieces[1:]):
        content.append({"type": "image", "image": slides[slide_idx]})
        content.append({"type": "text", "text": piece})

    messages = [{"role": "user", "content": content}]

    # Preparation for inference
    prompt = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    print(prompt)

    image_inputs, video_inputs = process_vision_info(messages)
    return processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
# Global variables to keep track of current document
# Index into `ds` of the document currently shown; updated by
# get_next_document / get_prev_document.
current_doc_index = 0
# NOTE(review): `annotations` is not referenced anywhere else in the visible
# file — confirm whether it is still needed.
annotations = []
def load_document(index):
    """Load one document from the dataset and prepare it for display.

    Args:
        index: Position of the document in ``ds``.

    Returns:
        A 4-tuple ``(title, abstract, html_body, token_count)`` where
        ``html_body`` is the interleaved slide/transcript markup rendered at
        scale 0.7 and ``token_count`` is the second dimension of the
        processor's ``input_ids``. If ``index`` is out of range, four empty
        strings are returned instead.
    """
    if not 0 <= index < len(ds):
        return ("", "", "", "")

    doc = ds[index]
    interleaved = segments(doc)
    slides = doc["slides"]

    body_html = create_interleaved_html(interleaved, slides, scale=0.7)
    n_tokens = doc_to_messages(interleaved, slides).input_ids.shape[1]
    return (doc["title"], doc["abstract"], body_html, n_tokens)
def get_next_document():
    """Advance to the following document, wrapping to the first after the last.

    Updates the module-level ``current_doc_index`` and returns the result of
    ``load_document`` for the new index.
    """
    global current_doc_index
    doc_count = len(ds)
    current_doc_index = (current_doc_index + 1) % doc_count
    return load_document(current_doc_index)
def get_prev_document():
    """Step back to the preceding document, wrapping to the last before the first.

    Updates the module-level ``current_doc_index`` and returns the result of
    ``load_document`` for the new index.
    """
    global current_doc_index
    doc_count = len(ds)
    # Python's modulo is non-negative for a positive divisor, so stepping
    # back from index 0 lands on the last document.
    current_doc_index = (current_doc_index - 1) % doc_count
    return load_document(current_doc_index)
# Build the Gradio interface: an HTML pane with the interleaved document next
# to its metadata, plus previous/next navigation buttons.
theme = gr.themes.Ocean()

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Slide Presentation Visualization Tool")

    with gr.Row():
        with gr.Column():
            body = gr.HTML(max_height=400)

            # Function to update the interleaved view: passes the current
            # body value straight back to the body component.
            def update_interleaved_view(title, abstract, body, token_count):
                return body

        with gr.Column():
            title = gr.Textbox(label="Title", interactive=False, max_lines=1)
            abstract = gr.Textbox(label="Abstract", interactive=False, max_lines=8)
            token_count = gr.Textbox(
                label=f"Token Count (Qwen2-VL with under {max_token_budget} tokens per image)",
                interactive=False,
                max_lines=1,
            )

    title.change(
        fn=update_interleaved_view,
        inputs=[title, abstract, body, token_count],
        outputs=body,
    )

    # Load first document and use it as the initial value of each component.
    initial_title, initial_abstract, initial_body, initial_token_count = load_document(
        current_doc_index
    )
    title.value = initial_title
    abstract.value = initial_abstract
    body.value = initial_body
    token_count.value = str(initial_token_count)

    # Navigation: both buttons refresh all four components.
    with gr.Row():
        prev_button = gr.Button("Previous Document")
        prev_button.click(
            fn=get_prev_document,
            inputs=[],
            outputs=[title, abstract, body, token_count],
        )
        next_button = gr.Button("Next Document")
        next_button.click(
            fn=get_next_document,
            inputs=[],
            outputs=[title, abstract, body, token_count],
        )

demo.launch()