import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import io
import base64
from datasets import load_dataset
# Budget the visual tokens per image: Qwen2-VL spends roughly one token per
# 28x28 pixels, so capping max_pixels bounds each image's token count.
max_token_budget = 512
min_pixels = 1 * 28 * 28
max_pixels = max_token_budget * 28 * 28

processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
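
# A minimal sketch (not called by the app) of estimating the visual-token cost
# of an image under this budget, assuming the processor clamps the image area
# into [min_pixels, max_pixels] and spends about one token per 28x28-pixel patch.
def estimate_image_tokens(width: int, height: int) -> int:
    pixels = min(max(width * height, min_pixels), max_pixels)
    return pixels // (28 * 28)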

ds = load_dataset("gigant/tib-bench-vlm")["train"]

def segments(example):
    """Interleave "<image>" placeholders with the transcript text, using the
    timestamps of the extracted keyframes to place each slide."""
    text = ""
    segment_i = 0
    for timestamp in example["keyframes"]["timestamp"]:
        text += "<image>"
        start, end = timestamp[0], timestamp[1]
        # Whisper "seek" offsets are in 10 ms frames, so * 0.01 converts to seconds
        while segment_i < len(example["transcript_segments"]["seek"]) and end > example["transcript_segments"]["seek"][segment_i] * 0.01:
            text += example["transcript_segments"]["text"][segment_i]
            segment_i += 1
    return text
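
# For example, a talk with two keyframes might produce something like
#   "<image> Hello and welcome ...<image> On this slide we see ..."
# where each "<image>" marks the point in the transcript at which a slide appears.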

def create_interleaved_html(text, slides, scale=0.4, max_width=600):
    """
    Creates an HTML string with interleaved images and text segments.
    The images are converted to base64 and embedded directly in the HTML.
    """
    html = []
    segments = text.split("<image>")
    for j, segment in enumerate(segments):
        # segments[0] is the empty string before the leading <image>,
        # so only emit an image for j > 0
        if j > 0:
            img = slides[j - 1]
            img_width = int(img.width * scale)
            img_height = int(img.height * scale)
            if img_width > max_width:
                ratio = max_width / img_width
                img_width = max_width
                img_height = int(img_height * ratio)
            # Convert the image to base64 for inline embedding
            buffer = io.BytesIO()
            img.resize((img_width, img_height)).save(buffer, format="PNG")
            img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
            html.append(f'<img src="data:image/png;base64,{img_str}" style="max-width: {max_width}px; display: block; margin: 20px auto;">')
        # Add the text segment after the image
        html.append(f'<div style="white-space: pre-wrap;">{segment}</div>')
    return "".join(html)

def doc_to_messages(text, slides):
    """Build a Qwen2-VL chat prompt from the interleaved text and slides and
    tokenize it, so the document's total token count can be measured."""
    content = []
    segments = text.split("<image>")
    for j, segment in enumerate(segments):
        if j > 0:
            content.append({"type": "image", "image": slides[j - 1]})
        content.append({"type": "text", "text": segment})
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return inputs
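
# A minimal sketch (never called by the app) of how these `inputs` would feed
# actual generation, assuming enough memory for the 2B model and that
# `accelerate` is installed for device_map="auto".
def generate_completion(inputs, max_new_tokens=128):
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
    )
    generated = model.generate(**inputs.to(model.device), max_new_tokens=max_new_tokens)
    # Drop the prompt tokens before decoding the model's answer
    new_tokens = generated[:, inputs.input_ids.shape[1]:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0]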

# Global index of the currently displayed document
current_doc_index = 0

def load_document(index):
    """Load a specific document from the dataset"""
    if 0 <= index < len(ds):
        doc = ds[index]
        segments_doc = segments(doc)
        return (
            doc["title"],
            doc["abstract"],
            create_interleaved_html(segments_doc, doc["slides"], scale=0.7),
            doc_to_messages(segments_doc, doc["slides"]).input_ids.shape[1],
        )
    return ("", "", "", "")

def get_next_document():
    """Advance to the next document, wrapping around at the end"""
    global current_doc_index
    current_doc_index = (current_doc_index + 1) % len(ds)
    return load_document(current_doc_index)

def get_prev_document():
    """Step back to the previous document; Python's modulo keeps the index
    non-negative, so going back from document 0 wraps to the last one"""
    global current_doc_index
    current_doc_index = (current_doc_index - 1) % len(ds)
    return load_document(current_doc_index)

theme = gr.themes.Ocean()

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Slide Presentation Visualization Tool")

    with gr.Row():
        with gr.Column():
            body = gr.HTML(max_height=400)

            # Pass-through: re-emit the HTML whenever the title changes so the
            # interleaved view stays in sync after navigation
            def update_interleaved_view(title, abstract, body, token_count):
                return body

        with gr.Column():
            title = gr.Textbox(label="Title", interactive=False, max_lines=1)
            abstract = gr.Textbox(label="Abstract", interactive=False, max_lines=8)
            token_count = gr.Textbox(label=f"Token Count (Qwen2-VL with under {max_token_budget} tokens per image)", interactive=False, max_lines=1)
            title.change(
                fn=update_interleaved_view,
                inputs=[title, abstract, body, token_count],
                outputs=body,
            )

    # Load the first document when the page is opened
    demo.load(
        fn=lambda: load_document(current_doc_index),
        inputs=[],
        outputs=[title, abstract, body, token_count],
    )

    with gr.Row():
        prev_button = gr.Button("Previous Document")
        prev_button.click(fn=get_prev_document, inputs=[], outputs=[title, abstract, body, token_count])
        next_button = gr.Button("Next Document")
        next_button.click(fn=get_next_document, inputs=[], outputs=[title, abstract, body, token_count])

demo.launch()