#!/usr/bin/env python

from __future__ import annotations

import os
import pathlib

import gradio as gr

from prismer_model import Model


def create_demo():
    model = Model()
    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Input', type='filepath')
            model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
            question = gr.Text(label='Question')
            run_button = gr.Button('Run')
        with gr.Column(scale=1.5):
            answer = gr.Text(label='Model Prediction')
            with gr.Row():
                depth = gr.Image(label='Depth')
                edge = gr.Image(label='Edge')
                normals = gr.Image(label='Normals')
            with gr.Row():
                segmentation = gr.Image(label='Segmentation')
                object_detection = gr.Image(label='Object Detection')
                ocr = gr.Image(label='OCR Detection')

    inputs = [image, model_name, question]
    outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]

    paths = sorted(pathlib.Path('prismer/images').glob('*'))
    ex_questions = ['What is the man on the left doing?',
                    'What is this person doing?',
                    'How many cows in this image?',
                    'What is the type of animal in this image?',
                    'What toy is it?']
    # Pair each example image with its question; zip() keeps the pairing safe
    # if the number of images and questions ever differ.
    examples = [[path.as_posix(), 'Prismer-Base', question]
                for path, question in zip(paths, ex_questions)]
    gr.Examples(examples=examples,
                inputs=inputs,
                outputs=outputs,
                fn=model.run_vqa,
                # Cache example outputs only when running on Hugging Face Spaces.
                cache_examples=os.getenv('SYSTEM') == 'spaces')

    run_button.click(fn=model.run_vqa, inputs=inputs, outputs=outputs)
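

# Reference only: a minimal sketch of the interface that the wiring above
# assumes from prismer_model.Model. This stand-in is an assumption for
# illustration, not the actual implementation, and is not used by the app.
# run_vqa receives (image_path, model_name, question) and must return seven
# values in the order of `outputs`: [answer, depth, edge, normals,
# segmentation, object_detection, ocr], where each image output can be a
# filepath or an array accepted by gr.Image.
class _ExampleVQAModel:
    def run_vqa(self, image_path: str, model_name: str, question: str):
        answer = 'example answer'
        expert_maps = [image_path] * 6  # placeholder expert visualisations
        return [answer] + expert_maps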


if __name__ == '__main__':
    # create_demo() only registers components, so build it inside an explicit
    # gr.Blocks context before launching.
    with gr.Blocks() as demo:
        create_demo()
    demo.queue().launch()
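
# Hypothetical usage (an assumption, not part of this file): an app.py in the
# same project could embed this builder as one tab of a larger interface,
# assuming this file is saved as app_vqa.py:
#
#     import gradio as gr
#     from app_vqa import create_demo
#
#     with gr.Blocks() as app:
#         with gr.Tab('Visual Question Answering'):
#             create_demo()
#
#     app.queue().launch()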