shuanholmes committed
Commit bf00c4c
1 Parent(s): d429710

[FireFlow] Init Commit

Files changed (3):
  1. app.py +84 -74
  2. flux/modules/layers.py +38 -12
  3. flux/sampling.py +19 -15
app.py CHANGED
@@ -45,24 +45,26 @@ def encode(init_image, torch_device):
     init_image = ae.encode(init_image.to()).to(torch.bfloat16)
     return init_image
 
-
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+offload = True
 device = "cuda" if torch.cuda.is_available() else "cpu"
 name = 'flux-dev'
-ae = load_ae(name, device)
+ae = load_ae(name, device="cpu" if offload else torch_device)
 t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
 clip = load_clip(device)
-model = load_flow_model(name, device=device)
-offload = False
-name = "flux-dev"
+model = load_flow_model(name, device="cpu" if offload else torch_device)
+if offload:
+    model.cpu()
+    torch.cuda.empty_cache()
+    ae.encoder.to(torch_device)
 is_schnell = False
-feature_path = 'feature'
 output_dir = 'result'
 add_sampling_metadata = True
 
 @spaces.GPU(duration=120)
 @torch.inference_mode()
-def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
-
+def edit(init_image, source_prompt, target_prompt, editing_strategy, num_steps, inject_step, guidance, seed):
+    global ae, t5, clip, model, name, is_schnell, output_dir, add_sampling_metadata
     device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.cuda.empty_cache()
     seed = None
@@ -76,15 +78,12 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
 
     width, height = init_image.shape[0], init_image.shape[1]
 
-
     init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
     init_image = init_image.unsqueeze(0)
     init_image = init_image.to(device)
     with torch.no_grad():
         init_image = ae.encode(init_image.to()).to(torch.bfloat16)
 
-    print(init_image.shape)
-
     rng = torch.Generator(device="cpu")
     opts = SamplingOptions(
         source_prompt=source_prompt,
@@ -97,6 +96,11 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
     )
     if opts.seed is None:
         opts.seed = torch.Generator(device="cpu").seed()
+
+    if offload:
+        ae = ae.cpu()
+        torch.cuda.empty_cache()
+        t5, clip = t5.to(torch_device), clip.to(torch_device)
 
     print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}")
     t0 = time.perf_counter()
@@ -106,12 +110,23 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
     #############inverse#######################
     info = {}
     info['feature'] = {}
-    info['inject_step'] = inject_step
+    info['inject_step'] = min(inject_step, num_steps)
+    info['reuse_v'] = False
+    info['editing_strategy'] = " ".join(editing_strategy)
+    info['start_layer_index'] = 20
+    info['end_layer_index'] = 37
+    qkv_ratio = '1.0,1.0,1.0'
+    info['qkv_ratio'] = list(map(float, qkv_ratio.split(',')))
 
     with torch.no_grad():
         inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
         inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
         timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
+
+    if offload:
+        t5, clip = t5.cpu(), clip.cpu()
+        torch.cuda.empty_cache()
+        model = model.to(torch_device)
 
     # inversion initial noise
     with torch.no_grad():
@@ -137,6 +152,11 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
         idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
     else:
         idx = 0
+
+    if offload:
+        model.cpu()
+        torch.cuda.empty_cache()
+        ae.decoder.to(x.device)
 
     device = torch.device("cuda")
     with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
@@ -166,97 +186,87 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
     return img
 
 
-
-def create_demo(model_name: str, device: str = "cuda:0" if torch.cuda.is_available() else "cpu", offload: bool = False):
+def create_demo(model_name: str, device: str = "cuda:0" if torch.cuda.is_available() else "cpu"):
     is_schnell = model_name == "flux-schnell"
     title = r"""
-    <h1 align="center">🪄 Taming Rectified Flow for Inversion and Editing</h1>
+    <h1 align="center">🔥FireFlow: Fast Inversion of Rectified Flow for Image Semantic Editing</h1>
     """
-
     description = r"""
-    <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/wangjiangshan0725/RF-Solver-Edit' target='_blank'><b>Taming Rectified Flow for Inversion and Editing</b></a>.<br>
-
-    ❗️❗️❗️[<b>Important</b>] Editing steps:<br>
-    1️⃣ Upload images you want to edit (The resolution is expected be less than 1360*768, or the memory of GPU may be not enough.) <br>
-    2️⃣ Enter the source prompt, which describes the content of the image you unload. The source prompt is not mandatory; you can also leave it to null. <br>
-    3️⃣ Enter the target prompt which describes the expected content of the edited image. <br>
-    4️⃣ Click the <b>Generate</b> button to start editing. <br>
-    5️⃣ We suggest to adjust the value of **feature sharing steps** for better results.<br>
-    """
-    article = r"""
-    If our work is helpful, please help to ⭐ the <a href='https://github.com/wangjiangshan0725/RF-Solver-Edit' target='_blank'>Github Repo</a>. Thanks!
+    <b>Official 🤗 Gradio Demo</b> for <a href='https://github.com/HolmesShuan/FireFlow-Fast-Inversion-of-Rectified-Flow-for-Image-Semantic-Editing' target='_blank'><b>🔥FireFlow: Fast Inversion of Rectified Flow for Image Semantic Editing</b></a>.<br>
     """
-
-    badge = r"""
-    [![GitHub Stars](https://img.shields.io/github/stars/wangjiangshan0725/RF-Solver-Edit?style=social)](https://github.com/wangjiangshan0725/RF-Solver-Edit)
+    article = r"""
+    If you find our work helpful, we would greatly appreciate it if you could ⭐ our <a href='https://github.com/HolmesShuan/FireFlow-Fast-Inversion-of-Rectified-Flow-for-Image-Semantic-Editing' target='_blank'>GitHub repository</a>. Thank you for your support!
     """
-
     css = '''
     .gradio-container {width: 85% !important}
     '''
     with gr.Blocks(css=css) as demo:
-        # gr.Markdown(f"# Official Demo for Taming Rectified Flow for Inversion and Editing")
-
+        # Add a title, description, and additional information
         gr.HTML(title)
         gr.Markdown(description)
         gr.Markdown(article)
-        gr.Markdown(badge)
 
+        # Layout: Two columns
         with gr.Row():
+            # Left Column: Inputs
            with gr.Column():
-                source_prompt = gr.Textbox(label="Source Prompt", value="")
-                target_prompt = gr.Textbox(label="Target Prompt", value="")
-                # source_prompt = gr.Text(
-                #     label="Source Prompt",
-                #     show_label=False,
-                #     max_lines=1,
-                #     placeholder="Enter your source prompt",
-                #     container=False,
-                #     value=""
-                # )
-                # target_prompt = gr.Text(
-                #     label="Target Prompt",
-                #     show_label=False,
-                #     max_lines=1,
-                #     placeholder="Enter your target prompt",
-                #     container=False,
-                #     value=""
-                # )
                 init_image = gr.Image(label="Input Image", visible=True)
-
-
+                source_prompt = gr.Textbox(label="Source Prompt", value="", placeholder="(Optional) Describe the content of the uploaded image.")
+                target_prompt = gr.Textbox(label="Target Prompt", value="", placeholder="(Required) Describe the desired content of the edited image.")
+                # CheckboxGroup for editing strategies
+                editing_strategy = gr.CheckboxGroup(
+                    label="Editing Technique",
+                    choices=['replace_v', 'add_q', 'add_k'],
+                    value=['replace_v'],  # Default: 'replace_v' selected
+                    interactive=True
+                )
                 generate_btn = gr.Button("Generate")
 
+            # Right Column: Advanced options and output
            with gr.Column():
                with gr.Accordion("Advanced Options", open=True):
-                    num_steps = gr.Slider(1, 30, 25, step=1, label="Total timesteps")
-                    inject_step = gr.Slider(1, 15, 3, step=1, label="Feature sharing steps")
-                    guidance = gr.Slider(1.0, 10.0, 2, step=0.1, label="Guidance", interactive=not is_schnell)
-                    # seed = gr.Textbox(0, label="Seed (-1 for random)", visible=False)
-                    # add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=False)
+                    num_steps = gr.Slider(
+                        minimum=1,
+                        maximum=30,
+                        value=8,
+                        step=1,
+                        label="Total timesteps"
+                    )
+                    inject_step = gr.Slider(
+                        minimum=1,
+                        maximum=15,
+                        value=1,
+                        step=1,
+                        label="Feature sharing steps"
+                    )
+                    guidance = gr.Slider(
+                        minimum=1.0,
+                        maximum=8.0,
+                        value=2.0,
+                        step=0.1,
+                        label="Guidance",
+                        interactive=not is_schnell
+                    )
 
+                # Output display
                output_image = gr.Image(label="Generated Image")
 
+        # Button click event to trigger the edit function
        generate_btn.click(
            fn=edit,
-            inputs=[init_image, source_prompt, target_prompt, num_steps, inject_step, guidance],
+            inputs=[
+                init_image,
+                source_prompt,
+                target_prompt,
+                editing_strategy,  # Include the selected editing strategies
+                num_steps,
+                inject_step,
+                guidance
+            ],
            outputs=[output_image]
        )
-
-
+
    return demo
 
-
-# if __name__ == "__main__":
-#     import argparse
-#     parser = argparse.ArgumentParser(description="Flux")
-#     parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
-#     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
-#     parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
-#     parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
-
-#     parser.add_argument("--port", type=int, default=41035)
-#     args = parser.parse_args()
-
 demo = create_demo("flux-dev", "cuda")
 demo.launch()
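
Note on the offload changes above: with offload = True, each Flux component is resident on the GPU only for the phase that needs it (VAE encoder for encoding, T5/CLIP for prompt preparation, the flow model for inversion and denoising, the VAE decoder for decoding), with torch.cuda.empty_cache() called after each hand-off. A minimal sketch of that pattern, assuming only PyTorch; run_phase_on_gpu is a hypothetical helper mirroring the inline offload blocks, not part of the commit:

    import torch

    def run_phase_on_gpu(module, phase, *args, offload=True, device="cuda"):
        # Move one module onto the GPU, run its phase, then evict it again
        # so the next (large) module fits in memory.
        if offload:
            module.to(device)
        try:
            return phase(*args)
        finally:
            if offload:
                module.cpu()
                torch.cuda.empty_cache()

The trade-off is extra host-device transfer time per phase in exchange for a peak-memory footprint bounded by the single largest component rather than the sum of all of them.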
flux/modules/layers.py CHANGED
@@ -243,21 +243,47 @@ class SingleStreamBlock(nn.Module):
         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
         q, k = self.norm(q, k, v)
 
-        # Note: If the memory of your device is not enough, you may consider uncomment the following code.
-        # if info['inject'] and info['id'] > 19:
-        #     store_path = os.path.join(info['feature_path'], str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V' + '.pth')
-        #     if info['inverse']:
-        #         torch.save(v, store_path)
-        #     if not info['inverse']:
-        #         v = torch.load(store_path, weights_only=True)
-
         # Save the features in the memory
-        if info['inject'] and info['id'] > 19:
-            feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V'
+        if info['inject'] and info['id'] <= info['end_layer_index'] and info['id'] >= info['start_layer_index']:
+            v_feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V'
+            k_feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'K'
+            q_feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'Q'
             if info['inverse']:
-                info['feature'][feature_name] = v.cpu()
+                if info['reuse_v']:
+                    info['feature'][v_feature_name] = v.cpu()
+                else:
+                    editing_strategy = info['editing_strategy']
+                    qkv_ratio = info['qkv_ratio']
+                    if 'q' in editing_strategy:
+                        info['feature'][q_feature_name] = (q * qkv_ratio[0]).cpu()
+                    if 'k' in editing_strategy:
+                        info['feature'][k_feature_name] = (k * qkv_ratio[1]).cpu()
+                    if 'v' in editing_strategy:
+                        info['feature'][v_feature_name] = (v * qkv_ratio[2]).cpu()
             else:
-                v = info['feature'][feature_name].cuda()
+                if info['reuse_v']:
+                    if v_feature_name in info['feature']:
+                        v = info['feature'][v_feature_name].cuda()
+                else:
+                    editing_strategy = info['editing_strategy']
+                    if 'replace_v' in editing_strategy:
+                        if v_feature_name in info['feature']:
+                            v = info['feature'][v_feature_name].cuda()
+                    if 'add_v' in editing_strategy:
+                        if v_feature_name in info['feature']:
+                            v += info['feature'][v_feature_name].cuda()
+                    if 'replace_k' in editing_strategy:
+                        if k_feature_name in info['feature']:
+                            k = info['feature'][k_feature_name].cuda()
+                    if 'add_k' in editing_strategy:
+                        if k_feature_name in info['feature']:
+                            k += info['feature'][k_feature_name].cuda()
+                    if 'replace_q' in editing_strategy:
+                        if q_feature_name in info['feature']:
+                            q = info['feature'][q_feature_name].cuda()
+                    if 'add_q' in editing_strategy:
+                        if q_feature_name in info['feature']:
+                            q += info['feature'][q_feature_name].cuda()
 
         # compute attention
         attn = attention(q, k, v, pe=pe)
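
Note on the cache keys above: each saved tensor is pinned to an exact timestep, solver stage, block index, and stream type, so a feature recorded during inversion is re-applied at exactly the same point in the network during editing. A small illustrative sketch of the keying and the replace/add semantics (feature_key is a hypothetical helper; names mirror the info dict in the diff):

    def feature_key(t, second_order, block_id, block_type, which):
        # e.g. feature_key(0.75, False, 20, 'single', 'V') -> '0.75_False_20_single_V'
        return f"{t}_{second_order}_{block_id}_{block_type}_{which}"

    # Inversion pass:  cache[feature_key(..., 'V')] = v.cpu()
    # Editing pass:    'replace_v' -> v = cache[key].cuda()   (overwrite with source feature)
    #                  'add_q'     -> q += cache[key].cuda()  (blend source feature in)

Replacing V keeps the source image's content while Q/K attention is recomputed for the target prompt; the add_* variants blend rather than overwrite, which is a softer form of injection.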
flux/sampling.py CHANGED
@@ -97,6 +97,7 @@ def denoise(
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
 
     step_list = []
+    next_step_velocity = None
     for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         info['t'] = t_prev if inverse else t_curr
@@ -104,20 +105,23 @@
         info['second_order'] = False
         info['inject'] = inject_list[i]
 
-        pred, info = model(
-            img=img,
-            img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
-            timesteps=t_vec,
-            guidance=guidance_vec,
-            info=info
-        )
-
+        if next_step_velocity is None:
+            pred, info = model(
+                img=img,
+                img_ids=img_ids,
+                txt=txt,
+                txt_ids=txt_ids,
+                y=vec,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                info=info
+            )
+        else:
+            pred = next_step_velocity
+
         img_mid = img + (t_prev - t_curr) / 2 * pred
 
-        t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
+        t_vec_mid = torch.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype, device=img.device)
         info['second_order'] = True
         pred_mid, info = model(
             img=img_mid,
@@ -129,9 +133,9 @@
             guidance=guidance_vec,
             info=info
         )
-
-        first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
-        img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order
+        next_step_velocity = pred_mid
+
+        img = img + (t_prev - t_curr) * pred_mid
 
     return img, info
 
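
Note on the sampling change above: the old loop evaluated the model twice per step (once at t_curr, once at the midpoint) and combined the two velocities in a Taylor-style second-order update. The new loop takes a plain midpoint step and caches the midpoint velocity so it can serve as the first evaluation of the next step; after the first iteration, each step therefore costs a single model call. A stripped-down sketch of the integrator, with a generic velocity(x, t) standing in for the Flux transformer call (illustrative, not the commit's code):

    def midpoint_with_reuse(x, timesteps, velocity):
        # velocity(x, t) is a stand-in for the model forward pass.
        v_next = None
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            dt = t_prev - t_curr
            v = velocity(x, t_curr) if v_next is None else v_next  # reuse cached velocity
            x_mid = x + dt / 2 * v                                 # half step to the midpoint
            v_mid = velocity(x_mid, t_curr + dt / 2)               # midpoint velocity
            x = x + dt * v_mid                                     # full midpoint (RK2) update
            v_next = v_mid                                         # cache for the next step
        return x

Reusing the previous midpoint velocity as the next step's initial velocity estimate is the approximation that lets the loop keep a second-order-style update at roughly first-order cost.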