ujin-song committed
Commit 57d7bf6
Parent(s): 4fba92e

Update app.py -- first released version

Files changed (1)
  1. app.py +256 -93
app.py CHANGED
@@ -1,96 +1,248 @@
 import gradio as gr
 import numpy as np
 import random
-from diffusers import DiffusionPipeline
 import torch

-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-if torch.cuda.is_available():
-    torch.cuda.max_memory_allocated(device=device)
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-    pipe.enable_xformers_memory_efficient_attention()
-    pipe = pipe.to(device)
-else:
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
-    pipe = pipe.to(device)
+import io, json
+from PIL import Image
+import os.path
+from weight_fusion import compose_concepts
+from regionally_controlable_sampling import sample_image, build_model, prepare_text

-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
+device = "cuda" if torch.cuda.is_available() else "cpu"
+power_device = "GPU" if torch.cuda.is_available() else "CPU"

-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+MAX_SEED = 100_000

+def generate(region1_concept,
+             region2_concept,
+             prompt,
+             region1_prompt,
+             region2_prompt,
+             negative_prompt,
+             region_neg_prompt,
+             seed,
+             randomize_seed,
+             sketch_adaptor_weight,
+             keypose_adaptor_weight
+             ):
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
+
+    region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
+    pretrained_model = merge(region1_concept, region2_concept)
+
+    keypose_condition = 'multi-concept/pose_data/two_apart.png'
+    region1 = '[0, 0, 512, 290]'
+    region2 = '[0, 650, 512, 910]'
+
+    region1_prompt = f'[<{region1_concept}1> <{region1_concept}2>, {region1_prompt}]'
+    region2_prompt = f'[<{region2_concept}1> <{region2_concept}2>, {region2_prompt}]'
+    prompt_rewrite=f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
+
+    result = infer(pretrained_model,
+                   prompt,
+                   prompt_rewrite,
+                   negative_prompt,
+                   seed,
+                   keypose_condition,
+                   keypose_adaptor_weight,
+                   # sketch_condition,
+                   # sketch_adaptor_weight,
+                   )

-    image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        guidance_scale = guidance_scale,
-        num_inference_steps = num_inference_steps,
-        width = width,
-        height = height,
-        generator = generator
-    ).images[0]
+    return result
+
+def merge(concept1, concept2):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    c1, c2 = sorted([concept1, concept2])
+    assert c1!=c2
+    merge_name = c1+'_'+c2
+
+    save_path = f'experiments/multi-concept/{merge_name}'
+
+    if os.path.isdir(save_path):
+        print(f'{save_path} already exists. Collecting merged weights from existing weights...')
+
+    else:
+        os.makedirs(save_path)
+        json_path = os.path.join(save_path,'merge_config.json')
+        alpha = 1.8
+        data = [
+            {
+                "lora_path": f"experiments/single-concept/{c1}/models/edlora_model-latest.pth",
+                "unet_alpha": alpha,
+                "text_encoder_alpha": alpha,
+                "concept_name": f"<{c1}1> <{c1}2>"
+            },
+            {
+                "lora_path": f"experiments/single-concept/{c2}/models/edlora_model-latest.pth",
+                "unet_alpha": alpha,
+                "text_encoder_alpha": alpha,
+                "concept_name": f"<{c2}1> <{c2}2>"
+            }
+        ]
+        with io.open(json_path,'w',encoding='utf8') as outfile:
+            json.dump(data, outfile, indent = 4, ensure_ascii=False)
+
+        compose_concepts(
+            concept_cfg=json_path,
+            optimize_textenc_iters=500,
+            optimize_unet_iters=50,
+            pretrained_model_path="nitrosocke/mo-di-diffusion",
+            save_path=save_path,
+            suffix='base',
+            device=device,
+        )
+        print(f'Merged weight for {c1}+{c2} saved in {save_path}!\n\n')
+
+    modelbase_path = os.path.join(save_path,'combined_model_base')
+    assert os.path.isdir(modelbase_path)
+
+    # save_path = 'experiments/multi-concept/elsa_moana_weight18/combined_model_base'
+    return modelbase_path
+
+def infer(pretrained_model,
+          prompt,
+          prompt_rewrite,
+          negative_prompt='',
+          seed=16141,
+          keypose_condition=None,
+          keypose_adaptor_weight=1.0,
+          sketch_condition=None,
+          sketch_adaptor_weight=0.0,
+          region_sketch_adaptor_weight='',
+          region_keypose_adaptor_weight=''
+          ):
+
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    pipe = build_model(pretrained_model, device)
+
+    if sketch_condition is not None and os.path.exists(sketch_condition):
+        sketch_condition = Image.open(sketch_condition).convert('L')
+        width_sketch, height_sketch = sketch_condition.size
+        print('use sketch condition')
+    else:
+        sketch_condition, width_sketch, height_sketch = None, 0, 0
+        print('skip sketch condition')
+
+    if keypose_condition is not None and os.path.exists(keypose_condition):
+        keypose_condition = Image.open(keypose_condition).convert('RGB')
+        width_pose, height_pose = keypose_condition.size
+        print('use pose condition')
+    else:
+        keypose_condition, width_pose, height_pose = None, 0, 0
+        print('skip pose condition')
+
+    if width_sketch != 0 and width_pose != 0:
+        assert width_sketch == width_pose and height_sketch == height_pose, 'conditions should be same size'
+    width, height = max(width_pose, width_sketch), max(height_pose, height_sketch)
+    kwargs = {
+        'sketch_condition': sketch_condition,
+        'keypose_condition': keypose_condition,
+        'height': height,
+        'width': width,
+    }
+
+    prompts = [prompt]
+    prompts_rewrite = [prompt_rewrite]
+    input_prompt = [prepare_text(p, p_w, height, width) for p, p_w in zip(prompts, prompts_rewrite)]
+    save_prompt = input_prompt[0][0]
+    print(save_prompt)
+
+    image = sample_image(
+        pipe,
+        input_prompt=input_prompt,
+        input_neg_prompt=[negative_prompt] * len(input_prompt),
+        generator=torch.Generator(device).manual_seed(seed),
+        sketch_adaptor_weight=sketch_adaptor_weight,
+        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
+        keypose_adaptor_weight=keypose_adaptor_weight,
+        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
+        **kwargs)

-    return image
+    return image[0]

-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
+examples_context = [
+    'walking at Stanford university campus',
+    'in a castle',
+    'in the forest',
+    'in front of Eiffel tower'
 ]

+examples_region1 = ['wearing red hat, high resolution, best quality','bright smile, wearing pants, best quality']
+examples_region2 = ['smilling, wearing blue shirt, high resolution, best quality']
+
 css="""
 #col-container {
     margin: 0 auto;
-    max-width: 520px;
+    max-width: 600px;
 }
 """

-if torch.cuda.is_available():
-    power_device = "GPU"
-else:
-    power_device = "CPU"
-
 with gr.Blocks(css=css) as demo:

     with gr.Column(elem_id="col-container"):
         gr.Markdown(f"""
-        # Text-to-Image Gradio Template
+        # Orthogonal Adaptation
         Currently running on {power_device}.
         """)
-
+        prompt = gr.Text(
+            label="ContextPrompt",
+            show_label=False,
+            max_lines=1,
+            placeholder="Enter your context prompt for overall image",
+            container=False,
+        )
         with gr.Row():
-
-            with gr.Column():
-                prompt = gr.Text(
-                    label="Prompt",
-                    show_label=False,
-                    max_lines=1,
-                    placeholder="Enter your prompt",
-                    container=False,
-                )
-                prompt2 = gr.Text(
-                    label="Prompt2",
-                    show_label=False,
-                    max_lines=1,
-                    placeholder="Enter your prompt for right character",
-                    container=False,
-                )

-            run_button = gr.Button("Run", scale=0)
+            region1_concept = gr.Dropdown(
+                ["Elsa", "Moana"],
+                label="Character 1",
+                info="Will add more characters later!"
+            )
+            region2_concept = gr.Dropdown(
+                ["Elsa", "Moana"],
+                label="Character 2",
+                info="Will add more characters later!"
+            )
+
+        with gr.Row():
+
+            region1_prompt = gr.Textbox(
+                label="Region1 Prompt",
+                show_label=False,
+                max_lines=2,
+                placeholder="Enter your prompt for character 1",
+                container=False,
+            )
+
+            region2_prompt = gr.Textbox(
+                label="Region2 Prompt",
+                show_label=False,
+                max_lines=2,
+                placeholder="Enter your prompt for character 2",
+                container=False,
+            )
+
+        run_button = gr.Button("Run", scale=1)

         result = gr.Image(label="Result", show_label=False)

         with gr.Accordion("Advanced Settings", open=False):
-
+
             negative_prompt = gr.Text(
-                label="Negative prompt",
+                label="Context Negative prompt",
+                max_lines=1,
+                value = 'saturated, cropped, worst quality, low quality',
+                visible=False,
+            )
+
+            region_neg_prompt = gr.Text(
+                label="Regional Negative prompt",
                 max_lines=1,
-                placeholder="Enter a negative prompt",
+                value = 'shirtless, nudity, saturated, cropped, worst quality, low quality',
                 visible=False,
             )

@@ -106,49 +258,60 @@ with gr.Blocks(css=css) as demo:

             with gr.Row():

-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
+                sketch_adaptor_weight = gr.Slider(
+                    label="Sketch Adapter Weight",
+                    minimum = 0,
+                    maximum = 1,
+                    step=0.01,
+                    value=0,
                 )

-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
+                keypose_adaptor_weight = gr.Slider(
+                    label="Keypose Adapter Weight",
+                    minimum = 0,
+                    maximum = 1,
+                    step= 0.01,
+                    value=1.0,
                 )

-            with gr.Row():
-
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=12,
-                    step=1,
-                    value=2,
-                )
-
+
         gr.Examples(
-            examples = examples,
+            label = 'Context Prompt example',
+            examples = examples_context,
             inputs = [prompt]
-        )
+        )
+
+        with gr.Row():
+            gr.Examples(
+                label = 'Region1 Prompt example',
+                examples = examples_region1,
+                inputs = [region1_prompt]
+            )
+
+            gr.Examples(
+                label = 'Region2 Prompt example',
+                examples = [examples_region2],
+                inputs = [region2_prompt]
+            )
+

     run_button.click(
-        fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+        fn = generate,
+        inputs = [region1_concept,
+                  region2_concept,
+                  prompt,
+                  region1_prompt,
+                  region2_prompt,
+                  negative_prompt,
+                  region_neg_prompt,
+                  seed,
+                  randomize_seed,
+                  # sketch_condition,
+                  # keypose_condition,
+                  sketch_adaptor_weight,
+                  keypose_adaptor_weight
+                  ],
         outputs = [result]
     )

-demo.queue().launch()
+demo.queue().launch(share=True)
 
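A note on the regional prompt format: `generate` encodes each region as `prompt-*-negative prompt-*-bounding box` and joins the two regions with `|`, after wrapping each character's custom tokens (`<elsa1> <elsa2>`, `<moana1> <moana2>`) into the region prompt. A minimal sketch of the assembled value for the Elsa/Moana pair, using illustrative region prompts (the boxes are the values hard-coded in `generate`):

```python
# Sketch of the prompt_rewrite string generate() assembles (Elsa + Moana).
# The per-region prompts are illustrative; the boxes match app.py.
region1_prompt = "[<elsa1> <elsa2>, wearing red hat, high resolution, best quality]"
region2_prompt = "[<moana1> <moana2>, smiling, wearing blue shirt, high resolution, best quality]"
region_neg_prompt = "shirtless, nudity, saturated, cropped, worst quality, low quality"
region1 = "[0, 0, 512, 290]"
region2 = "[0, 650, 512, 910]"

# Each region is "prompt-*-negative prompt-*-box"; regions are joined by "|".
prompt_rewrite = (
    f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}"
    f"|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
)
print(prompt_rewrite)
```

This is the string that `prepare_text` from `regionally_controlable_sampling` later splits back into per-region prompts for the regionally controllable sampler.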
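For a concrete picture of the fusion config, this is what `merge('elsa', 'moana')` would write to `experiments/multi-concept/elsa_moana/merge_config.json` before invoking `compose_concepts`; a sketch that simply mirrors the f-strings above, and the referenced paths exist only if the single-concept EDLoRA checkpoints are present:

```python
import json

# Mirror of the merge_config.json that merge("elsa", "moana") writes;
# every value follows the literals in app.py above.
alpha = 1.8
merge_config = [
    {
        "lora_path": "experiments/single-concept/elsa/models/edlora_model-latest.pth",
        "unet_alpha": alpha,
        "text_encoder_alpha": alpha,
        "concept_name": "<elsa1> <elsa2>",
    },
    {
        "lora_path": "experiments/single-concept/moana/models/edlora_model-latest.pth",
        "unet_alpha": alpha,
        "text_encoder_alpha": alpha,
        "concept_name": "<moana1> <moana2>",
    },
]
print(json.dumps(merge_config, indent=4))
```

Because `merge` sorts the pair and returns early when `experiments/multi-concept/<c1>_<c2>` already exists, the costly `compose_concepts` step runs once per unordered character pair, and later requests reuse the cached `combined_model_base`.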
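Finally, the new entry point can be exercised without the UI. A minimal sketch, assuming the Space's `weight_fusion` and `regionally_controlable_sampling` modules are importable, the single-concept checkpoints are in place, and (as in upstream Mix-of-Show) `sample_image` returns a list of PIL images:

```python
# Hypothetical headless call into app.py's generate(); keyword names
# match the function signature in the diff above.
image = generate(
    region1_concept="Elsa",
    region2_concept="Moana",
    prompt="in the forest",
    region1_prompt="wearing red hat, high resolution, best quality",
    region2_prompt="smiling, wearing blue shirt, high resolution, best quality",
    negative_prompt="saturated, cropped, worst quality, low quality",
    region_neg_prompt="shirtless, nudity, saturated, cropped, worst quality, low quality",
    seed=16141,
    randomize_seed=False,
    sketch_adaptor_weight=0.0,
    keypose_adaptor_weight=1.0,
)
image.save("elsa_moana_forest.png")  # assumes a PIL.Image return value
```

Note that `sketch_adaptor_weight` is accepted but currently unused: the sketch arguments in the `infer` call are commented out, so only the keypose adaptor affects the result in this version.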