File size: 4,713 Bytes
9c00f5c
 
 
 
8be0786
 
9c00f5c
 
55cf602
 
7cebd9b
55cf602
 
f4b5fc1
 
55cf602
 
80597e4
 
9c00f5c
80597e4
9c00f5c
 
8fad05b
 
 
 
 
9c00f5c
 
 
80597e4
5d7586b
80597e4
9c00f5c
 
 
8be0786
9c00f5c
 
 
 
 
 
 
5bd6412
9c00f5c
 
 
 
 
 
 
 
 
 
5bd6412
9c00f5c
 
 
 
 
 
 
 
80597e4
 
9c00f5c
 
 
40139b3
9c00f5c
38ca5e2
 
9c00f5c
8fad05b
 
 
 
 
8be0786
9c00f5c
 
 
 
 
 
 
 
8be0786
80597e4
 
 
9c00f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8be0786
80597e4
8be0786
 
9c00f5c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python

from __future__ import annotations

import gradio as gr

from model import AppModel

# Markdown rendered at the top of the page (passed to gr.Markdown in main()).
DESCRIPTION = '''# <a href="https://github.com/THUDM/CogView2">CogView2</a> (text2image)

This Spaces demo runs only one of the two stages of the CogView2 codebase due to GPU hardware limitations, so the outputs may not match those of the original codebase/paper.
This application accepts English or Chinese as input.
In general, Chinese input produces better results than English input.
If you check the "Translate to Chinese" checkbox, the app will use the English to Chinese translation results with [this Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) as input.
But the translation model may mistranslate and the results could be poor.
So, it is also a good idea to input the translation results from other translation services.
'''
# Markdown rendered below the input/output columns.
NOTES = '''
- This app is adapted from <a href="https://github.com/hysts/CogView2_demo">https://github.com/hysts/CogView2_demo</a>. We recommend using that repo if you want to run the app yourself.
'''
# Raw HTML image tag for a visitor-counter badge, rendered at the page bottom.
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=THUDM.CogView2" />'


def set_example_text(example: list) -> list[dict]:
    """Turn a clicked example row into updates for the text and style inputs.

    The first field of ``example`` becomes the textbox value and the second
    field becomes the selected dropdown value. The two updates are returned
    in that order, matching the ``[text, style]`` components of the dataset.
    """
    prompt, chosen_style = example[0], example[1]
    return [
        gr.Textbox.update(value=prompt),
        gr.Dropdown.update(value=chosen_style),
    ]


def _load_samples(path: str = 'samples.txt') -> list[list[str]]:
    """Read tab-separated example rows from ``path``.

    Each non-blank line is stripped and split on tabs into one sample row.
    Blank lines (e.g. an extra empty line at the end of the file) are
    skipped so they do not turn into malformed single-field samples.
    """
    # The app accepts Chinese prompts, so the examples may contain non-ASCII
    # text: decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default encoding.
    with open(path, encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        return [line.split('\t') for line in stripped_lines if line]


def main():
    """Build the CogView2 Gradio Blocks UI and launch it with queueing enabled."""
    only_first_stage = True
    max_inference_batch_size = 8
    model = AppModel(max_inference_batch_size, only_first_stage)

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column():
                with gr.Group():
                    text = gr.Textbox(label='Input Text')
                    translate = gr.Checkbox(label='Translate to Chinese',
                                            value=False)
                    style = gr.Dropdown(choices=[
                        'none',
                        'mainbody',
                        'photo',
                        'flat',
                        'comics',
                        'oil',
                        'sketch',
                        'isometric',
                        'chinese',
                        'watercolor',
                    ],
                                        value='mainbody',
                                        label='Style')
                    seed = gr.Slider(0,
                                     100000,
                                     step=1,
                                     value=1234,
                                     label='Seed')
                    # Hidden when only_first_stage is True, so the value sent
                    # to the model stays fixed at the configured setting.
                    # Named distinctly so it does not shadow the bool above.
                    only_first_stage_checkbox = gr.Checkbox(
                        label='Only First Stage',
                        value=only_first_stage,
                        visible=not only_first_stage)
                    num_images = gr.Slider(1,
                                           16,
                                           step=1,
                                           value=4,
                                           label='Number of Images')
                    run_button = gr.Button('Run')

                    # Each sample row supplies values for [text, style].
                    samples = _load_samples('samples.txt')
                    examples = gr.Dataset(components=[text, style],
                                          samples=samples)

            with gr.Column():
                with gr.Group():
                    translated_text = gr.Textbox(label='Translated Text')
                    with gr.Tabs():
                        with gr.TabItem('Output (Grid View)'):
                            result_grid = gr.Image(show_label=False)
                        with gr.TabItem('Output (Gallery)'):
                            result_gallery = gr.Gallery(show_label=False)

        gr.Markdown(NOTES)
        gr.Markdown(FOOTER)

        run_button.click(fn=model.run_with_translation,
                         inputs=[
                             text,
                             translate,
                             style,
                             seed,
                             only_first_stage_checkbox,
                             num_images,
                         ],
                         outputs=[
                             translated_text,
                             result_grid,
                             result_gallery,
                         ])
        # Clicking an example fills the text box and style dropdown.
        examples.click(fn=set_example_text,
                       inputs=examples,
                       outputs=examples.components)

    demo.launch(enable_queue=True)


if __name__ == '__main__':
    # Build and launch the demo only when executed as a script, not on import.
    main()