File size: 3,681 Bytes
eef0d11
 
 
 
 
 
 
 
 
 
 
a8ae40c
 
eef0d11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f0b955
eef0d11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20ddf99
 
 
 
 
 
 
 
 
 
 
 
 
eef0d11
 
 
 
 
 
 
 
 
 
 
 
 
20ddf99
 
eef0d11
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import json
import requests

# Remote JSON index: a list of problem entries. The code in this file reads the
# 'index', 'instruction', 'image' and 'results' fields of each entry.
data_url = "http://opencompass.openxlab.space/utils/RiseVis/data.json"
# A timeout keeps startup from hanging forever on an unresponsive host, and
# raise_for_status() reports HTTP failures directly instead of letting an
# error page surface later as a confusing JSON decode error.
_response = requests.get(data_url, timeout=30)
_response.raise_for_status()
data = json.loads(_response.text)
# Get model names from the first entry
model_names = list(data[0]['results'])

# Opening / closing markup wrapped around the gallery table.
HTML_HEAD = '<table class="center">'
HTML_TAIL = '</table>'
# Number of gallery cells per table row, and the percentage width of each cell.
N_COL = 5
WIDTH = 100 // N_COL

def get_image_gallery(idx, models):
    """Render an HTML table comparing the selected models' outputs for one problem.

    Models are sorted by name and laid out ``N_COL`` per row. Each group of
    models produces a row of name headers followed by a row of result images;
    incomplete groups are padded with empty cells so every row has ``N_COL``
    columns.

    Args:
        idx: Value of the ``'index'`` field of the entry to display.
        models: Iterable of model names; each must be a key of the entry's
            ``'results'`` mapping.

    Returns:
        The table markup as a single string (``HTML_HEAD`` ... ``HTML_TAIL``).

    Raises:
        TypeError: If ``idx`` is not a string.
        ValueError: If ``idx`` does not match exactly one entry in ``data``.
        KeyError: If a requested model has no result in the matched entry.
    """
    # Local import: keeps this function self-contained and avoids clashing
    # with any local variable named ``html``.
    from html import escape

    # Explicit exceptions instead of ``assert`` so validation survives ``-O``.
    if not isinstance(idx, str):
        raise TypeError(f'idx must be a str, got {type(idx).__name__}')
    matches = [x for x in data if x['index'] == idx]
    if len(matches) != 1:
        raise ValueError(
            f'expected exactly one entry with index {idx!r}, found {len(matches)}'
        )
    results = matches[0]['results']

    cell_open = f'<td width={WIDTH}% style="text-align:center;">'
    empty_cell = cell_open + '</td>'
    ordered = sorted(models)

    parts = [HTML_HEAD]
    for start in range(0, len(ordered), N_COL):
        group = ordered[start:start + N_COL]
        padding = empty_cell * (N_COL - len(group))

        # Header row: model names, escaped because they come from remote data.
        parts.append('<tr>')
        parts.extend(
            f'{cell_open}<h3>{escape(str(name))}</h3></td>' for name in group
        )
        parts.append(padding)

        # Image row: the URL is escaped for safe use inside the src attribute.
        parts.append('</tr><tr>')
        parts.extend(
            f'{cell_open}<img src="{escape(URL_BASE + results[name], quote=True)}"></td>'
            for name in group
        )
        parts.append(padding)
        parts.append('</tr>')
    parts.append(HTML_TAIL)
    return ''.join(parts)

# Prefix prepended to the file names stored in the data entries (the 'image'
# field and the values of 'results') to form full image URLs. It is defined
# after get_image_gallery, which is fine because the name is only looked up
# when that function is called.
URL_BASE = 'https://opencompass.openxlab.space/utils/RiseVis/'

def get_origin_image(idx, model='original'):
    """Return the URL of an image belonging to one problem entry.

    Args:
        idx: Value of the ``'index'`` field of the entry to look up.
        model: ``'original'`` (default) selects the entry's ``'image'`` field;
            any other value is used as a key into the entry's ``'results'``
            mapping.

    Returns:
        ``URL_BASE`` joined with the selected file name.

    Raises:
        TypeError: If ``idx`` is not a string.
        ValueError: If ``idx`` does not match exactly one entry in ``data``.
        KeyError: If ``model`` is not ``'original'`` and has no result in the
            matched entry.
    """
    # Explicit exceptions instead of ``assert`` so validation survives ``-O``.
    if not isinstance(idx, str):
        raise TypeError(f'idx must be a str, got {type(idx).__name__}')
    matches = [x for x in data if x['index'] == idx]
    if len(matches) != 1:
        raise ValueError(
            f'expected exactly one entry with index {idx!r}, found {len(matches)}'
        )
    entry = matches[0]
    if model == 'original':
        file_name = entry['image']
    else:
        # Look up the requested model by the *argument*; the previous code
        # indexed with the literal string 'model' and ignored the parameter.
        file_name = entry['results'][model]
    return URL_BASE + file_name

def read_instruction(idx):
    """Return the ``'instruction'`` field of the entry whose index is ``idx``.

    Args:
        idx: Value of the ``'index'`` field of the entry to look up.

    Returns:
        The matched entry's ``'instruction'`` value.

    Raises:
        TypeError: If ``idx`` is not a string.
        ValueError: If ``idx`` does not match exactly one entry in ``data``.
    """
    # Explicit exceptions instead of ``assert`` so validation survives ``-O``
    # (otherwise an unknown index would surface as a bare IndexError).
    if not isinstance(idx, str):
        raise TypeError(f'idx must be a str, got {type(idx).__name__}')
    matches = [x for x in data if x['index'] == idx]
    if len(matches) != 1:
        raise ValueError(
            f'expected exactly one entry with index {idx!r}, found {len(matches)}'
        )
    return matches[0]['instruction']

def on_prev(state):
    """Return the index of the entry preceding ``state`` in ``data``.

    The value is returned twice because the click handler writes it to two
    outputs. Stepping back from the first entry wraps around to the last one.

    Args:
        state: ``'index'`` value of the currently shown entry.

    Returns:
        A ``(index, index)`` tuple for the previous entry. If ``state`` matches
        no entry, the first entry's index is returned so navigation can recover
        from an invalid value; if ``data`` is empty, ``state`` is returned
        unchanged.
    """
    if not data:
        # Nothing to navigate to; previously this raised NameError.
        return state, state
    position = next(
        (i for i, entry in enumerate(data) if entry['index'] == state), None
    )
    if position is None:
        # Unknown state: the old loop fell through and silently used the
        # second-to-last entry. Restart from the first entry instead.
        target = data[0]['index']
    else:
        # position - 1 is -1 for the first entry, which indexes the last one.
        target = data[position - 1]['index']
    return target, target

def on_next(state):
    """Return the index of the entry following ``state`` in ``data``.

    The value is returned twice because the click handler writes it to two
    outputs. Stepping forward from the last entry wraps around to the first
    one, mirroring how ``on_prev`` wraps from the first entry to the last.

    Args:
        state: ``'index'`` value of the currently shown entry.

    Returns:
        A ``(index, index)`` tuple for the next entry. If ``state`` matches no
        entry, the first entry's index is returned so navigation can recover
        from an invalid value; if ``data`` is empty, ``state`` is returned
        unchanged.
    """
    if not data:
        # Nothing to navigate to; previously this raised NameError.
        return state, state
    position = next(
        (i for i, entry in enumerate(data) if entry['index'] == state), None
    )
    if position is None:
        # Unknown state: restart from the first entry.
        target = data[0]['index']
    else:
        # Modulo wraps past the end; the old ``data[i + 1]`` raised IndexError
        # when the last entry was shown.
        target = data[(position + 1) % len(data)]['index']
    return target, target
    

# Build the Gradio UI: a title, one row holding three columns (navigation and
# model selection, the instruction text, the input image), and the HTML
# gallery of model outputs underneath.
with gr.Blocks() as demo:
    gr.Markdown("# Gallery of Generation Results on RISEBench")

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Row():
                prev_button = gr.Button("PREV")
                next_button = gr.Button("NEXT")
                # Editable textbox showing the problem index.
                problem_index = gr.Textbox(value='causal_reasoning_1', label='Problem Index', interactive=True, visible=True)
                # Hidden component holding the current problem index; it is the
                # input of the instruction, image and gallery components below.
                state = gr.Markdown(value='causal_reasoning_1', label='Current Problem Index', visible=False)   
                def update_state(problem_index):
                    # Identity function: copies the textbox value into `state`.
                    return problem_index
                # Submitting the textbox writes its value to `state`.
                problem_index.submit(fn=update_state, inputs=[problem_index], outputs=[state])
                # PREV / NEXT compute the neighbouring index from `state` and
                # write it to both `state` and the visible textbox.
                prev_button.click(fn=on_prev, inputs=[state], outputs=[state, problem_index])
                next_button.click(fn=on_next, inputs=[state], outputs=[state, problem_index])

            # All models are selected by default.
            model_checkboxes = gr.CheckboxGroup(label="Select Models", choices=model_names, value=model_names)

        # NOTE(review): the components below pass a callable as `value` together
        # with `inputs`; this presumably relies on Gradio re-evaluating the
        # callable whenever those inputs change — confirm for the Gradio
        # version in use.
        with gr.Column(scale=2):
            instruction = gr.Textbox(label="Instruction", interactive=False, value=read_instruction, inputs=[state])
        
        with gr.Column(scale=1):
            # Only `state` is passed, so get_origin_image uses its default
            # model='original'.
            image = gr.Image(label="Input Image", value=get_origin_image, inputs=[state])

    # Gallery HTML is computed from the current index and the checked models.
    gallery = gr.HTML(value=get_image_gallery, inputs=[state, model_checkboxes])

# Launch the app only when this file is executed as a script.
# NOTE(review): server_name='0.0.0.0' conventionally means "listen on all
# network interfaces" — confirm this exposure is intended for the deployment.
if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0')