"""Gradio viewer for the WMTIS identify test split: normal vs. strange image pairs."""
from datasets import load_dataset
import gradio as gr
import math
import os
import random

wmtis = load_dataset("nlphuji/wmtis-identify")['test']
print("Loaded WMTIS identify, first example:")
print(wmtis[0])
# Largest valid row index (inclusive); used as the slider maximum.
dataset_size = len(wmtis) - 1

NORMAL_IMAGE = 'normal_image'
STRANGE_IMAGE = 'strange_image'

# Row fields shown for each side, in the same order as the output components below.
NORMAL_KEYS = ['normal_image', 'normal_hash', 'normal_image_caption', 'rating_normal', 'comments_normal']
STRANGE_KEYS = ['strange_image', 'strange_hash', 'strange_image_caption', 'rating_strange', 'comments_strange']
# Only these fields are passed through the empty-comment normaliser.
COMMENT_KEYS = {'comments_normal', 'comments_strange'}


def get_empty_comment_if_needed(item):
    """Return '-' for a missing comment, otherwise return the comment unchanged.

    The original code treated the string 'nan' as "no comment"; ``None`` and a
    float NaN are treated the same way so either representation displays '-'.
    """
    if item is None or item == 'nan':
        return '-'
    if isinstance(item, float) and math.isnan(item):
        return '-'
    return item


def func(index):
    """Return the component values for dataset row *index*.

    The list holds the five normal-image values followed by the five
    strange-image values, matching ``normal_outputs + strange_outputs``.
    Comment fields are normalised with ``get_empty_comment_if_needed``.
    """
    # Slider values may be delivered as floats; dataset rows need an int index.
    example = wmtis[int(index)]
    outputs = []
    for key in NORMAL_KEYS + STRANGE_KEYS:
        value = example[key]
        outputs.append(get_empty_comment_if_needed(value) if key in COMMENT_KEYS else value)
    return outputs


with gr.Blocks() as demo:
    gr.Markdown("# Slide to iterate WMTIS: Normal vs. Strange Images")
    with gr.Column():
        # randint is inclusive, so every row (including the last) can be the start row.
        index = random.randint(0, dataset_size)
        initial = func(index)
        # Start the slider on the randomly chosen row so position and content agree.
        slider = gr.Slider(minimum=0, maximum=dataset_size, step=1, value=index)
        with gr.Row():
            with gr.Column():
                i1 = gr.Image(value=initial[0], label='Normal Image')
                t1 = gr.Textbox(value=initial[1], label='Image ID')
                p1 = gr.Textbox(value=initial[2], label='BLIP2 Predicted Caption')
                r1 = gr.Textbox(value=initial[3], label='Rating')
                c1 = gr.Textbox(value=initial[4], label='Comments')
            normal_outputs = [i1, t1, p1, r1, c1]
            with gr.Column():
                i2 = gr.Image(value=initial[5], label='Strange Image')
                t2 = gr.Textbox(value=initial[6], label='Image ID')
                p2 = gr.Textbox(value=initial[7], label='BLIP2 Predicted Caption')
                r2 = gr.Textbox(value=initial[8], label='Rating')
                c2 = gr.Textbox(value=initial[9], label='Comments')
            strange_outputs = [i2, t2, p2, r2, c2]
        slider.change(func, inputs=[slider], outputs=normal_outputs + strange_outputs)

demo.launch()