File size: 3,462 Bytes
72fb101
 
4e4df51
 
 
 
c4454c9
4e4df51
 
 
72fb101
 
 
 
 
 
3b91aca
4e4df51
 
72fb101
 
be53140
4e4df51
9febe95
4e4df51
 
 
 
 
 
72fb101
4e4df51
72fb101
 
 
 
 
 
be53140
72fb101
 
 
 
 
 
 
 
 
435f95e
72fb101
4e4df51
 
435f95e
4e4df51
 
435f95e
 
 
 
 
4e4df51
435f95e
4e4df51
435f95e
 
 
4e4df51
435f95e
 
 
 
 
72fb101
 
4e4df51
 
 
 
 
72fb101
435f95e
4e4df51
 
 
 
 
 
 
72fb101
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import random

import gradio as gr
from datasets import load_dataset

# Load the WHOOPS! test split once at import time; every card below reads from it.
whoops = load_dataset("nlphuji/whoops")["test"]
dataset_size = len(whoops)
# Startup diagnostics printed to the server log.
print("Loaded WMTIS, first example:")
print(whoops[0])
print(f"all dataset size: {dataset_size}")

# Column names of the WHOOPS dataset rows.
IMAGE = 'image'
IMAGE_DESIGNER = 'image_designer'
DESIGNER_EXPLANATION = 'designer_explanation'
CROWD_CAPTIONS = 'crowd_captions'
CROWD_EXPLANATIONS = 'crowd_explanations'
CROWD_UNDERSPECIFIED_CAPTIONS = 'crowd_underspecified_captions'
SELECTED_CAPTION = 'selected_caption'
COMMONSENSE_CATEGORY = 'commonsense_category'
QA = 'question_answering_pairs'
IMAGE_ID = 'image_id'

# Columns rendered directly on each card (just the image).
left_side_columns = [IMAGE]
# All remaining dataset columns except the QA pairs, in feature order; rendered in the details accordion.
right_side_columns = [
    column
    for column in whoops.features
    if column not in left_side_columns and column != QA
]
# Columns whose values are lists, rendered as numbered lines.
enumerate_cols = [CROWD_CAPTIONS, CROWD_EXPLANATIONS, CROWD_UNDERSPECIFIED_CAPTIONS]
# Emoji suffix appended to each column's Textbox label.
# NOTE(review): many of these literals look like UTF-8 emoji mis-decoded through a
# legacy code page (e.g. the recurring 'πŸ' prefix); confirm how they render.
emoji_to_label = {IMAGE_DESIGNER: '🎨, πŸ§‘β€πŸŽ¨, πŸ’»', DESIGNER_EXPLANATION: 'πŸ’‘, πŸ€”, πŸ§‘β€πŸŽ¨',
                  CROWD_CAPTIONS: 'πŸ‘₯, πŸ’¬, πŸ“', CROWD_EXPLANATIONS: 'πŸ‘₯, πŸ’‘, πŸ€”', CROWD_UNDERSPECIFIED_CAPTIONS: 'πŸ‘₯, πŸ’¬, πŸ‘Ž',
                  QA: '❓, πŸ€”, πŸ’‘', IMAGE_ID: 'πŸ”, πŸ“„, πŸ’Ύ', COMMONSENSE_CATEGORY: 'πŸ€”, πŸ“š, πŸ’‘', SELECTED_CAPTION: 'πŸ“, πŸ‘Œ, πŸ’¬'}
# Size passed to the image's resize() before display, so every card has the same footprint.
target_size = (1024, 1024)


def func(index):
    """Return the rendered column values of the WHOOPS example at position *index*."""
    return get_instance_values(whoops[index])


def get_instance_values(example):
    """Turn one dataset row into the display values for every shown column.

    Values are returned in ``left_side_columns + right_side_columns`` order.
    List-valued columns in ``enumerate_cols`` become numbered multi-line strings.
    QA pairs become numbered "Q: ... A: ..." lines. Every other column is
    returned unchanged (e.g. the image object).
    """
    def render(column):
        raw = example[column]
        if column in enumerate_cols:
            return list_to_string(raw)
        if column == QA:
            # Each QA entry is indexed as [question, answer].
            return list_to_string([f"Q: {pair[0]} A: {pair[1]}" for pair in raw])
        return raw

    return [render(column) for column in left_side_columns + right_side_columns]


def list_to_string(lst):
    """Join *lst* into newline-separated lines numbered from 1 ("1. a\\n2. b")."""
    numbered_lines = (f"{position}. {item}" for position, item in enumerate(lst, start=1))
    return "\n".join(numbered_lines)

def _column_label(key):
    """Return the Textbox label for dataset column *key* (humanized name plus emoji suffix)."""
    label = key.capitalize().replace("_", " ")
    emoji = emoji_to_label.get(key)
    # A column with no emoji entry gets a plain label instead of crashing the page with KeyError.
    return f"{label} {emoji}" if emoji else label


def plot_image(index):
    """Render one WHOOPS example as Gradio components in the current layout context.

    Must be called inside an active ``gr.Blocks`` context (the module calls it
    inside a ``gr.Column``). The components it creates attach to that context.
    The image goes on the card itself. The remaining columns go into a collapsed
    "Click for details" accordion.

    Args:
        index: Position of the example in the ``whoops`` dataset.

    Raises:
        ValueError: if ``get_instance_values`` does not return one value per
            displayed column.
    """
    # Fetch the row once and reuse it. Each ``whoops[index]`` re-reads the row and re-decodes the image.
    example = whoops[index]
    instance_values = get_instance_values(example)
    split = len(left_side_columns)
    expected = split + len(right_side_columns)
    if len(instance_values) != expected:
        raise ValueError(
            f"expected {expected} column values, got {len(instance_values)}")

    for key, value in zip(left_side_columns, instance_values[:split]):
        if key == IMAGE:
            # get_instance_values passes the image through unchanged; resize it to
            # the shared target size so the grid cards line up.
            gr.Image(value=value.resize(target_size), label=example[COMMONSENSE_CATEGORY])
        else:
            gr.Textbox(value=value, label=_column_label(key))

    with gr.Accordion("Click for details", open=False):
        for key, value in zip(right_side_columns, instance_values[split:]):
            gr.Textbox(value=value, label=_column_label(key))


columns_number = 4
rows_number = 30
# Draw the displayed examples uniformly at random without replacement. Cap the draw
# at the dataset size, because random.sample rejects k > population.
# (Previously a shuffled subset was built but never used, so the first 120 rows were
# always shown.)
sample_size = min(columns_number * rows_number, dataset_size)
sample_indices = random.sample(range(dataset_size), sample_size)
# Kept for inspection; holds exactly the rows rendered below, in display order.
whoops_sample = whoops.select(sample_indices)
index = 0

with gr.Blocks() as demo:
    gr.Markdown("# WHOOPS! Dataset Explorer")
    # One gr.Row per chunk of `columns_number` sampled examples; each example gets its own column.
    for row_start in range(0, sample_size, columns_number):
        with gr.Row():
            for index in sample_indices[row_start:row_start + columns_number]:
                with gr.Column():
                    plot_image(index)
demo.launch()