yonatanbitton committed on
Commit
1aecc62
β€’
1 Parent(s): 319b393
Files changed (2) hide show
  1. app.py +116 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import math
from datasets import load_dataset
import gradio as gr
import os

# Hugging Face token for the (gated/private) dataset; configured as a Space secret.
auth_token = os.environ.get("HF_TOKEN")
# NOTE(review): never print the token value itself — that leaks a secret into the logs.
if auth_token is None:
    print("Warning: HF_TOKEN is not set; attempting anonymous dataset access.")

# Load the test split of the Visual Riddles dataset.
Visual_Riddles = load_dataset("nitzanguetta/Visual_Riddles", token=auth_token)['test']
dataset_size = len(Visual_Riddles)

# Dataset column names, exactly as they appear in the dataset features.
IMAGE = 'Image'
QUESTION = 'Question'
ANSWER = "Answer"
CAPTION = "Image caption"
PROMPT = "Prompt"
MODEL_NAME = "Model name"
HINT = "Hint"
ATTRIBUTION = "Attribution"
DLI = "Difficulty Level Index"
CATEGORY = "Category"
DESIGNER = "Designer"

# The image is shown on the left; every remaining column goes to the right panel.
left_side_columns = [IMAGE]
right_side_columns = [x for x in Visual_Riddles.features.keys() if x not in left_side_columns]
# Decorative emoji appended to each field label in the UI.
emoji_to_label = {IMAGE: '🎨, πŸ§‘β€πŸŽ¨, πŸ’»', ANSWER: 'πŸ’‘, πŸ€”, πŸ§‘β€πŸŽ¨', QUESTION: '❓, πŸ€”, πŸ’‘', CATEGORY: 'πŸ€”, πŸ“š, πŸ’‘',
                  CAPTION: 'πŸ“, πŸ‘Œ, πŸ’¬', PROMPT: 'πŸ“, πŸ’»', MODEL_NAME: '🎨, πŸ’»', HINT: 'πŸ€”, πŸ”',
                  ATTRIBUTION: 'πŸ”, πŸ“„', DLI: "🌑️, πŸ€”, 🎯", DESIGNER: "πŸ§‘β€πŸŽ¨"}
batch_size = 8  # examples shown per page
target_size = (1024, 1024)  # every image is resized to this before display
def func(index):
    """Slider callback: return the flattened display values for page `index`.

    The returned list must line up, in order and count, with the components
    wired as `outputs` of `slider.change` (left columns first, then right
    columns, repeated per example in the page).
    """
    start_index = index * batch_size
    end_index = start_index + batch_size
    # NOTE(review): when dataset_size is not a multiple of batch_size, the last
    # page reads past the dataset and Visual_Riddles[i] raises IndexError;
    # the output count must stay fixed for Gradio, so clamping is non-trivial.
    values_lst = []
    # The original comprehension's loop variable shadowed the `index` parameter;
    # use a distinct name for clarity.
    for i in range(start_index, end_index):
        values_lst += get_instance_values(Visual_Riddles[i])
    return values_lst
def get_instance_values(example):
    """Return the display value for each column of one dataset `example`.

    The order matches left_side_columns + right_side_columns so the values
    line up with the UI components created in get_col() and updated by func().
    Images are resized to `target_size`; all other columns pass through as-is.
    """
    # Bug fix: the original branched on `enumerate_cols` and `QA`, neither of
    # which is defined anywhere in this file (enumerate_cols is commented out),
    # so every non-image column raised NameError. Those dead branches — leftovers
    # from a previous app — are removed.
    values = []
    for k in left_side_columns + right_side_columns:
        if k == IMAGE:
            values.append(example[IMAGE].resize(target_size))
        else:
            values.append(example[k])
    return values
def list_to_string(lst):
    """Render `lst` as a newline-separated, 1-based numbered list string."""
    numbered = (f"{pos}. {entry}" for pos, entry in enumerate(lst, start=1))
    return "\n".join(numbered)
# Root Gradio container; the layout is declared inside the `with demo:` block below.
demo = gr.Blocks()
def get_col(example):
    """Build the Gradio components that display one dataset example.

    Returns (inputs_left, text_inputs_right): the left-hand components (the
    image) and the textboxes placed inside a collapsible "Click for details"
    accordion. Component creation order matches the value order produced by
    get_instance_values(), which slider.change relies on.
    """
    instance_values = get_instance_values(example)
    n_left = len(left_side_columns)
    left_values = instance_values[:n_left]
    right_values = instance_values[n_left:]
    with gr.Column():
        inputs_left = []
        assert len(left_side_columns) == len(left_values)  # excluding the image & designer
        for key, value in zip(left_side_columns, left_values):
            if key == IMAGE:
                resized = example["Image"].resize(target_size)
                component = gr.Image(value=resized)
            else:
                pretty = key.capitalize().replace("_", " ")
                component = gr.Textbox(value=value, label=f"{pretty} {emoji_to_label[key]}")
            inputs_left.append(component)
        with gr.Accordion("Click for details", open=False):
            text_inputs_right = []
            assert len(right_side_columns) == len(right_values)  # excluding the image & designer
            for key, value in zip(right_side_columns, right_values):
                pretty = key.capitalize().replace("_", " ")
                text_inputs_right.append(
                    gr.Textbox(value=value, label=f"{pretty} {emoji_to_label[key]}")
                )
    return inputs_left, text_inputs_right
with demo:
    gr.Markdown("# Slide to iterate Visual Riddles")

    with gr.Column():
        # One page per batch of examples; the slider selects the page.
        num_batches = math.ceil(dataset_size / batch_size)
        # Bug fix: pages are 0-based, so the last valid page is num_batches - 1.
        # The original maximum=num_batches allowed selecting a page past the end,
        # which makes func() index beyond the dataset.
        slider = gr.Slider(minimum=0, maximum=num_batches - 1, step=1,
                           label=f'Page (out of {num_batches})')
        with gr.Row():
            # Build the components for the initial page (page 0); slider.change
            # later refreshes their values via func(). The original's
            # `index = slider.value` was dead code and `0 * batch_size` a
            # pointless expression — both removed.
            start_index = 0
            end_index = start_index + batch_size
            all_inputs_left_right = []
            for i in range(start_index, end_index):
                inputs_left, text_inputs_right = get_col(Visual_Riddles[i])
                all_inputs_left_right += inputs_left + text_inputs_right

    # Wire the slider: moving it replaces every component value on the page.
    slider.change(func, inputs=[slider], outputs=all_inputs_left_right)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ datasets
2
+ gradio==3.21.0