# File-view header (from the web page this was copied from):
# author yonatanbitton, commit 17d712b, size 2.58 kB.
from datasets import load_dataset
import gradio as gr
import os
import random
# Load the 'test' split of the nlphuji/wmtis-identify dataset.
wmtis = load_dataset("nlphuji/wmtis-identify")['test']
print("Loaded WMTIS identify, first example:")
print(wmtis[0])
# Largest valid row index (row count minus one), not the number of rows;
# used as the slider maximum and as the upper bound for the random start row.
dataset_size = len(wmtis) - 1
# Column names of the two image fields.
# NOTE(review): not referenced elsewhere in this file; kept in case other
# code imports them.
NORMAL_IMAGE = 'normal_image'
STRANGE_IMAGE = 'strange_image'
def func(index):
    """Collect the values shown for dataset row *index*.

    The returned list lines up with ``normal_outputs + strange_outputs`` as
    wired in ``slider.change``. It holds five values for the normal image,
    then the same five for the strange image, in this order:

    - the image itself;
    - its hash (shown as "Image ID");
    - its BLIP2 caption;
    - its rating;
    - its comments.

    Only the comment values are passed through
    ``get_empty_comment_if_needed``, which shows the placeholder string
    ``'nan'`` as ``'-'``. All other fields are returned unchanged.

    Args:
        index: Row index selected on the slider. It is converted with
            ``int()`` because the slider value may arrive as a float
            (NOTE(review): confirm for the Gradio version in use).

    Returns:
        A list of 10 values, one per output component.
    """
    example = wmtis[int(index)]
    field_groups = (
        ('normal_image', 'normal_hash', 'normal_image_caption',
         'rating_normal', 'comments_normal'),
        ('strange_image', 'strange_hash', 'strange_image_caption',
         'rating_strange', 'comments_strange'),
    )
    # Only the free-text comment columns get the 'nan' -> '-' substitution.
    # The check must use the key currently being processed, so that only the
    # strange image's own comment field is affected.
    comment_keys = {'comments_normal', 'comments_strange'}
    outputs = []
    for keys in field_groups:
        for key in keys:
            value = example[key]
            if key in comment_keys:
                value = get_empty_comment_if_needed(value)
            outputs.append(value)
    return outputs
# Top-level Gradio Blocks container. Its components are declared inside the
# `with demo:` section below, and the app is started with `demo.launch()`.
demo = gr.Blocks()
def get_empty_comment_if_needed(item):
    """Show the literal placeholder string ``'nan'`` as ``'-'``.

    Any other value, including non-string values, is returned unchanged.
    """
    return '-' if item == 'nan' else item
with demo:
    gr.Markdown("# Slide to iterate WMTIS: Normal vs. Strange Images")
    # Pick the initial example before building the slider so the slider can
    # start on it. randint is inclusive on both ends, so the last row
    # (index dataset_size) can be chosen too, and a one-row dataset works.
    index = random.randint(0, dataset_size)
    with gr.Column():
        # step=1 keeps the slider on whole row indices. value=index makes the
        # slider start on the same row as the example rendered below.
        slider = gr.Slider(minimum=0, maximum=dataset_size, step=1, value=index)
        with gr.Row():
            # Left column: the "normal" image of the row and its metadata.
            with gr.Column():
                i1 = gr.Image(value=wmtis[index]["normal_image"], label='Normal Image')
                t1 = gr.Textbox(value=wmtis[index]["normal_hash"], label='Image ID')
                p1 = gr.Textbox(value=wmtis[index]["normal_image_caption"], label='BLIP2 Predicted Caption')
                r1 = gr.Textbox(value=wmtis[index]["rating_normal"], label='Rating')
                c1 = gr.Textbox(value=get_empty_comment_if_needed(wmtis[index]["comments_normal"]), label='Comments')
                normal_outputs = [i1, t1, p1, r1, c1]
            # Right column: the "strange" image of the row and its metadata.
            with gr.Column():
                i2 = gr.Image(value=wmtis[index]["strange_image"], label='Strange Image')
                t2 = gr.Textbox(value=wmtis[index]["strange_hash"], label='Image ID')
                p2 = gr.Textbox(value=wmtis[index]["strange_image_caption"], label='BLIP2 Predicted Caption')
                r2 = gr.Textbox(value=wmtis[index]["rating_strange"], label='Rating')
                c2 = gr.Textbox(value=get_empty_comment_if_needed(wmtis[index]["comments_strange"]), label='Comments')
                strange_outputs = [i2, t2, p2, r2, c2]
    # Moving the slider re-renders all ten components from the selected row;
    # func returns its values in this same order.
    slider.change(func, inputs=[slider], outputs=normal_outputs + strange_outputs)
demo.launch()