adirik commited on
Commit
b5e8b97
1 Parent(s): ff50bb1

update app

Browse files
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. __init__.py +0 -0
  3. app.py +177 -20
  4. find_direction.py +1 -4
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
__init__.py ADDED
File without changes
app.py CHANGED
@@ -1,29 +1,186 @@
1
- import cv2
2
- import torch
3
- import clip
4
  import gradio as gr
 
 
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
 
 
7
 
8
- # Use GPU if available
9
- if torch.cuda.is_available():
10
- device = torch.device("cuda")
11
- else:
12
- device = torch.device("cpu")
 
 
 
 
13
 
 
14
 
15
- def manipulate_image(img, text_queries, score_threshold):
16
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
18
 
19
- description = """
20
- """
21
 
22
- demo = gr.Interface(
23
- manipulate_image,
24
- inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
25
- outputs="image",
26
- title="Text-guided image manipulation with StyleMC",
27
- description=description,
28
- )
29
- demo.launch()
 
 
 
 
1
  import gradio as gr
2
+ import legacy
3
+ import dnnlib
4
  import numpy as np
5
+ import torch
6
+
7
+ from find_direction import find_direction
8
+
9
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10
+ with dnnlib.util.open_url("./pretrained/ffhq.pkl") as f:
11
+ G = legacy.load_network_pkl(f)['G_ema'].to(device)
12
+
13
+
14
+ DESCRIPTION = '''# <a href="https://github.com/catlab-team/stylemc"> StyleMC:</a> Multi-Channel Based Fast Text-Guided Image Generation and Manipulation
15
+ '''
16
+ FOOTER = 'This space is built by <a href = "https://github.com/catlab-team">Catlab Team</a>.'
17
+
18
+
19
+ def main():
20
+ with gr.Blocks(css='style.css') as demo:
21
+ gr.Markdown(DESCRIPTION)
22
+
23
+ with gr.Box():
24
+ gr.Markdown('''## Step 1 (Finding a global manipulation direction)
25
+ - Please enter the target **text prompt** and **identity loss weight** to find global manipulation direction:
26
+ - Hit the **Find Direction** button.
27
+ ''')
28
+ with gr.Row():
29
+ with gr.Column():
30
+ with gr.Row():
31
+ text = gr.Textbox(
32
+ label="Enter your prompt",
33
+ show_label=False,
34
+ max_lines=1,
35
+ placeholder="Enter your prompt",
36
+ ).style(
37
+ container=False,
38
+ )
39
+ identity_loss_weight = gr.Slider(0.1,
40
+ 10,
41
+ value=0.5,
42
+ step=0.1,
43
+ label='Identity Loss Weight',
44
+ interactive=True)
45
+ btn = gr.Button("Find Direction").style(full_width=False)
46
+
47
+ with gr.Box():
48
+ gr.Markdown('''## Step 2 (Manipulation)
49
+ - Please upload an image for manipulation:
50
+ - You can also select the **previous directions** and determine the **manipulation strength**.
51
+ - Hit the **Generate** button.
52
+ ''')
53
+ with gr.Row():
54
+ identity_loss_weight = gr.Slider(0.1,
55
+ 100,
56
+ value=50,
57
+ step=0.1,
58
+ label='Manipulation Strength',
59
+ interactive=True)
60
+ with gr.Row():
61
+ with gr.Column():
62
+ with gr.Row():
63
+ input_image = gr.Image(label='Input Image',
64
+ type='filepath')
65
+ with gr.Row():
66
+ generate_button = gr.Button('Generate')
67
+ with gr.Column():
68
+ with gr.Row():
69
+ generated_image = gr.Image(label='Generated Image',
70
+ type='numpy',
71
+ interactive=False)
72
+
73
+
74
+
75
+
76
+ # with gr.Box():
77
+ # gr.Markdown('''## Step 2 (Select Style Image)
78
+ # - Select **Style Type**.
79
+ # - Select **Style Image Index** from the image table below.
80
+ # ''')
81
+ # with gr.Row():
82
+ # with gr.Column():
83
+ # style_type = gr.Radio(model.style_types,
84
+ # label='Style Type')
85
+ # text = get_style_image_markdown_text('cartoon')
86
+ # style_image = gr.Markdown(value=text)
87
+ # style_index = gr.Slider(0,
88
+ # 316,
89
+ # value=26,
90
+ # step=1,
91
+ # label='Style Image Index')
92
+
93
+ # with gr.Row():
94
+ # example_styles = gr.Dataset(
95
+ # components=[style_type, style_index],
96
+ # samples=[
97
+ # ['cartoon', 26],
98
+ # ['caricature', 65],
99
+ # ['arcane', 63],
100
+ # ['pixar', 80],
101
+ # ])
102
+
103
+ # with gr.Box():
104
+ # gr.Markdown('''## Step 3 (Generate Style Transferred Image)
105
+ # - Adjust **Structure Weight** and **Color Weight**.
106
+ # - These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image.
107
+ # - Hit the **Generate** button.
108
+ # ''')
109
+ # with gr.Row():
110
+ # with gr.Column():
111
+ # with gr.Row():
112
+ # structure_weight = gr.Slider(0,
113
+ # 1,
114
+ # value=0.6,
115
+ # step=0.1,
116
+ # label='Structure Weight')
117
+ # with gr.Row():
118
+ # color_weight = gr.Slider(0,
119
+ # 1,
120
+ # value=1,
121
+ # step=0.1,
122
+ # label='Color Weight')
123
+ # with gr.Row():
124
+ # structure_only = gr.Checkbox(label='Structure Only')
125
+ # with gr.Row():
126
+ # generate_button = gr.Button('Generate')
127
 
128
+ # with gr.Column():
129
+ # result = gr.Image(label='Result')
130
 
131
+ # with gr.Row():
132
+ # example_weights = gr.Dataset(
133
+ # components=[structure_weight, color_weight],
134
+ # samples=[
135
+ # [0.6, 1.0],
136
+ # [0.3, 1.0],
137
+ # [0.0, 1.0],
138
+ # [1.0, 0.0],
139
+ # ])
140
 
141
+ gr.Markdown(FOOTER)
142
 
143
+ # preprocess_button.click(fn=model.detect_and_align_face,
144
+ # inputs=input_image,
145
+ # outputs=aligned_face)
146
+ # aligned_face.change(fn=model.reconstruct_face,
147
+ # inputs=aligned_face,
148
+ # outputs=[
149
+ # reconstructed_face,
150
+ # instyle,
151
+ # ])
152
+ # style_type.change(fn=update_slider,
153
+ # inputs=style_type,
154
+ # outputs=style_index)
155
+ # style_type.change(fn=update_style_image,
156
+ # inputs=style_type,
157
+ # outputs=style_image)
158
+ # generate_button.click(fn=model.generate,
159
+ # inputs=[
160
+ # style_type,
161
+ # style_index,
162
+ # structure_weight,
163
+ # color_weight,
164
+ # structure_only,
165
+ # instyle,
166
+ # ],
167
+ # outputs=result)
168
+ # example_images.click(fn=set_example_image,
169
+ # inputs=example_images,
170
+ # outputs=example_images.components)
171
+ # example_styles.click(fn=set_example_styles,
172
+ # inputs=example_styles,
173
+ # outputs=example_styles.components)
174
+ # example_weights.click(fn=set_example_weights,
175
+ # inputs=example_weights,
176
+ # outputs=example_weights.components)
177
 
178
+ demo.launch(
179
+ # enable_queue=args.enable_queue,
180
+ # server_port=args.port,
181
+ # share=args.share,
182
+ )
183
 
 
 
184
 
185
+ if __name__ == '__main__':
186
+ main()
 
 
 
 
 
 
find_direction.py CHANGED
@@ -72,7 +72,7 @@ def unravel_index(index, shape):
72
  return tuple(reversed(out))
73
 
74
  def find_direction(
75
- network_pkl: str,
76
  text_prompt: str,
77
  truncation_psi: float = 0.7,
78
  noise_mode: str = "const",
@@ -82,10 +82,7 @@ def find_direction(
82
  seeds=np.random.randint(0, 1000, 128)
83
 
84
  batch_size=1
85
- print('Loading networks from "%s"...' % network_pkl)
86
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
87
- with dnnlib.util.open_url(network_pkl) as f:
88
- G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
89
 
90
  # Labels
91
  class_idx=None
 
72
  return tuple(reversed(out))
73
 
74
  def find_direction(
75
+ G,
76
  text_prompt: str,
77
  truncation_psi: float = 0.7,
78
  noise_mode: str = "const",
 
82
  seeds=np.random.randint(0, 1000, 128)
83
 
84
  batch_size=1
 
85
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
86
 
87
  # Labels
88
  class_idx=None