File size: 7,950 Bytes
e9e2aab
b5e8b97
 
e9e2aab
b5e8b97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9e2aab
b5e8b97
 
e9e2aab
b5e8b97
 
 
 
 
 
 
 
 
e9e2aab
b5e8b97
b45a032
b5e8b97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9e2aab
b5e8b97
 
 
 
 
e9e2aab
 
b5e8b97
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import gradio as gr
import legacy
import dnnlib
import numpy as np
import torch

from find_direction import find_direction

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the network pickle once at import time and keep its 'G_ema' entry
# (presumably the EMA generator weights — confirm against legacy.py),
# moved onto ``device``.
with dnnlib.util.open_url("./pretrained/ffhq.pkl") as pkl_stream:
    G = legacy.load_network_pkl(pkl_stream)['G_ema'].to(device)


# Markdown banner rendered at the top of the demo page (ends with a newline).
DESCRIPTION = (
    '# <a href="https://github.com/catlab-team/stylemc"> StyleMC:</a> '
    'Multi-Channel Based Fast Text-Guided Image Generation and Manipulation\n'
)
# HTML attribution rendered at the bottom of the demo page.
FOOTER = (
    'This space is built by '
    '<a href = "https://github.com/catlab-team">Catlab Team</a>.'
)


def main():
    """Build the StyleMC Gradio interface and launch it.

    The page is laid out in two steps:

    1. Finding a global manipulation direction from a text prompt and an
       identity loss weight ("Find Direction" button).
    2. Manipulating an uploaded image with a chosen manipulation strength
       ("Generate" button), showing the result next to the input.

    NOTE(review): no event handlers are attached yet. Neither button is
    connected to ``find_direction`` or to the module-level ``G``, so the UI
    currently renders but does nothing when clicked.
    """
    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        # Step 1: text prompt + identity loss weight -> manipulation direction.
        with gr.Box():
            gr.Markdown('''## Step 1 (Finding a global manipulation direction)
- Please enter the target **text prompt** and **identity loss weight** to find global manipulation direction:
- Hit the **Find Direction** button.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        text = gr.Textbox(
                            label="Enter your prompt",
                            show_label=False,
                            max_lines=1,
                            placeholder="Enter your prompt",
                        ).style(
                            container=False,
                        )
                        identity_loss_weight = gr.Slider(0.1,
                                                         10,
                                                         value=0.5,
                                                         step=0.1,
                                                         label='Identity Loss Weight',
                                                         interactive=True)
                        find_direction_button = gr.Button("Find Direction").style(full_width=False)

        # Step 2: apply a direction to an uploaded image.
        with gr.Box():
            gr.Markdown('''## Step 2 (Manipulation)
- Please upload an image for manipulation:
    - You can also select the **previous directions** and determine the **manipulation strength**.
- Hit the **Generate** button.
''')
            with gr.Row():
                # Own name, so it does not overwrite the Step 1
                # ``identity_loss_weight`` handle; both sliders must remain
                # addressable when callbacks are wired up.
                manipulation_strength = gr.Slider(0.1,
                                                  100,
                                                  value=50,
                                                  step=0.1,
                                                  label='Manipulation Strength',
                                                  interactive=True)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        input_image = gr.Image(label='Input Image',
                                               type='filepath')
                    with gr.Row():
                        generate_button = gr.Button('Generate')
                with gr.Column():
                    with gr.Row():
                        generated_image = gr.Image(label='Generated Image',
                                                   type='numpy',
                                                   interactive=False)

        gr.Markdown(FOOTER)

    demo.launch()


if __name__ == "__main__":
    # Build and serve the UI only when run as a script, not when imported.
    main()