oyly committed
Commit 87fa4fd · 1 Parent(s): a8d4753

first commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ .ipynb_checkpoints/
2
+ __pycache__/
3
+ *.pyc
app.py ADDED
@@ -0,0 +1,375 @@
1
+ import os
2
+ import re
3
+ import time
4
+ import math
5
+ from dataclasses import dataclass
6
+ from glob import iglob
7
+ import argparse
8
+ from einops import rearrange
9
+ from PIL import ExifTags, Image
10
+ import torch
11
+ import gradio as gr
12
+ import numpy as np
13
+ import spaces
14
+ from huggingface_hub import login
15
+ login(token=os.getenv('Token'))
16
+ from flux.sampling_lore import denoise, get_schedule, prepare, unpack, get_v_mask, add_masked_noise_to_z,get_mask_one_tensor, denoise_with_noise_optim,prepare_tokens
17
+ from flux.util_lore import (configs, embed_watermark, load_ae, load_clip,
18
+ load_flow_model, load_t5)
19
+
20
+ def encode(init_image, torch_device, ae):
21
+ init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
22
+ init_image = init_image.unsqueeze(0)
23
+ init_image = init_image.to(torch_device)
24
+ init_image = ae.encode(init_image.to()).to(torch.bfloat16)
25
+ return init_image
26
+ from torchvision import transforms
27
+ transform = transforms.ToTensor()
28
+
29
+ class FluxEditor_lore_demo:
30
+ def __init__(self, model_name):
31
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ self.offload = False
33
+
34
+ self.name = model_name
35
+ self.is_schnell = model_name == "flux-schnell"
36
+ self.resize_longside = 800
37
+ self.save = False
38
+
39
+ self.output_dir = 'outputs_gradio'
40
+
41
+ self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
42
+ self.clip = load_clip(self.device)
43
+ self.model = load_flow_model(model_name, device=self.device)
44
+ self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
45
+
46
+ self.t5.eval()
47
+ self.clip.eval()
48
+ self.ae.eval()
49
+ self.info = {}
50
+ if self.offload:
51
+ self.model.cpu()
52
+ torch.cuda.empty_cache()
53
+ self.ae.encoder.to(self.device)
54
+ for param in self.model.parameters():
55
+ param.requires_grad = False # freeze the model
56
+ for param in self.t5.parameters():
57
+ param.requires_grad = False # freeze the model
58
+ for param in self.clip.parameters():
59
+ param.requires_grad = False # freeze the model
60
+ for param in self.ae.parameters():
61
+ param.requires_grad = False # freeze the model
62
+
63
+ def resize_image(self,image):
64
+ pil_image = Image.fromarray(image)
65
+ h, w = pil_image.size[1], pil_image.size[0]
66
+ if h <= self.resize_longside and w <= self.resize_longside:
67
+ return image
68
+
69
+ if h >= w:
70
+ new_h = self.resize_longside
71
+ new_w = int(w * self.resize_longside / h)
72
+ else:
73
+ new_w = self.resize_longside
74
+ new_h = int(h * self.resize_longside / w)
75
+
76
+ resized_image = pil_image.resize((new_w, new_h), Image.LANCZOS)
77
+ return np.array(resized_image)
78
+
79
+ def resize_mask(self,mask,height,width):
80
+ pil_mask = Image.fromarray(mask.astype(np.uint8)) # ensure it's 8-bit for PIL
81
+ resized_pil = pil_mask.resize((width, height), Image.NEAREST) # width first!
82
+ return np.array(resized_pil)
83
+
84
+ @spaces.GPU(duration=240)
85
+ def inverse(self, brush_canvas,src_prompt,
86
+ inversion_num_steps, injection_num_steps,
87
+ inversion_guidance,
88
+ ):
89
+ print(f"Inverting {src_prompt}, guidance {inversion_guidance}, injection/steps {injection_num_steps}/{inversion_num_steps}")
90
+ self.z0 = None
91
+ self.zt = None
92
+ torch.cuda.empty_cache()
93
+ if self.info:
94
+ del self.info
95
+ self.info = {'src_p':src_prompt}
96
+
97
+ rgba_init_image = brush_canvas["background"]
98
+ init_image = rgba_init_image[:,:,:3]
99
+
100
+
101
+ if self.resize_longside != -1:
102
+ init_image = self.resize_image(init_image)
103
+ shape = init_image.shape
104
+
105
+ new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
106
+ new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
107
+
108
+ init_image = init_image[:new_h, :new_w, :]
109
+ width, height = init_image.shape[0], init_image.shape[1]
110
+ self.init_image = encode(init_image, self.device, self.ae)
111
+
112
+ if self.save:
113
+ ori_output_path = os.path.join(self.output_dir,f'{src_prompt[:20]}_ori.png')
114
+ Image.fromarray(init_image,'RGB').save(ori_output_path)
115
+
116
+ t0 = time.perf_counter()
117
+
118
+ self.info['feature'] = {}
119
+ self.info['inject_step'] = injection_num_steps
120
+ self.info['wh'] = (width, height)
121
+
122
+ torch.cuda.empty_cache()
123
+
124
+ inp = prepare(self.t5, self.clip, self.init_image, prompt=src_prompt)
125
+ timesteps = get_schedule(inversion_num_steps, inp["img"].shape[1], shift=True)
126
+ self.info['x_ori'] = inp["img"].clone()
127
+
128
+ # inversion initial noise
129
+ torch.set_grad_enabled(False)
130
+ z, info, _, _ = denoise(self.model, **inp, timesteps=timesteps, guidance=inversion_guidance, inverse=True, info=self.info)
131
+ self.z0 = z
132
+ self.info = info
133
+
134
+ t1 = time.perf_counter()
135
+ print(f"Inversion done in {t1 - t0:.1f}s.")
136
+ return init_image
137
+
138
+ @spaces.GPU(duration=240)
139
+ def edit(self, brush_canvas, source_prompt, inversion_guidance,
140
+ target_prompt, target_object,target_object_index,
141
+ inversion_num_steps, injection_num_steps,
142
+ training_epochs,
143
+ denoise_guidance,noise_scale,seed,
144
+ ):
145
+
146
+ torch.cuda.empty_cache()
147
+ if 'src_p' not in self.info or self.info['src_p'] != source_prompt:
148
+ print('Source prompt changed; running inversion again.')
149
+ self.inverse(brush_canvas,source_prompt,
150
+ inversion_num_steps, injection_num_steps,
151
+ inversion_guidance)
152
+
153
+ rgba_init_image = brush_canvas["background"]
154
+ rgba_mask = brush_canvas["layers"][0]
155
+ init_image = rgba_init_image[:,:,:3]
156
+ if self.resize_longside != -1:
157
+ init_image = self.resize_image(init_image)
158
+ width, height = self.info['wh']
159
+ init_image = init_image[:width, :height, :]
160
+ #rgba_init_image = rgba_init_image[:height, :width, :]
161
+
162
+ if self.resize_longside != -1:
163
+ mask = self.resize_mask(rgba_mask[:,:,3],height,width)
164
+ else:
165
+ mask = rgba_mask[:width, :height, 3]
166
+ mask = mask.astype(int)
167
+
168
+ rgba_mask[:,:,3] = rgba_mask[:,:,3]//2
169
+ masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA'))
170
+ masked_image = masked_image.resize((height, width), Image.LANCZOS)
171
+
172
+
173
+ # prepare source mask and vmask
174
+ inp_optim = prepare(self.t5, self.clip, self.init_image, prompt=target_prompt)
175
+ inp_target = prepare(self.t5, self.clip, self.init_image, prompt=target_prompt)
176
+ v_mask,source_mask = self.get_v_src_masks(mask,width,height,self.device)
177
+ self.info['change_v'] = 2 # v_mask
178
+ self.info['v_mask'] = v_mask
179
+ self.info['source_mask'] = source_mask
180
+ self.info['inject_step'] = injection_num_steps
181
+ timesteps = get_schedule(inversion_num_steps, inp_optim["img"].shape[1], shift=True)
182
+ seed = int(seed)
183
+ if seed == -1:
184
+ seed = torch.randint(0, 2**32, (1,)).item()
185
+
186
+ # prepare token_ids
187
+ token_ids=[]
188
+ replacements = [[None,target_object,-1,int(target_object_index)]]
189
+ src_dif_ids,tgt_dif_ids = prepare_tokens(self.t5, source_prompt, target_prompt, replacements,True)
190
+ for t_ids in tgt_dif_ids:
191
+ token_ids.append([t_ids,True,1])
192
+ print('token_ids',token_ids)
193
+ # do latent optim
194
+
195
+ t0 = time.perf_counter()
196
+ print(f'optimizing & editing noise, {target_prompt} with seed {seed}, noise_scale {noise_scale}, training_epochs {training_epochs}')
197
+ if training_epochs != 0:
198
+ torch.set_grad_enabled(True)
199
+ inp_optim["img"] = self.z0
200
+ _, info, _, _, trainable_noise_list = denoise_with_noise_optim(self.model,**inp_optim,token_ids=token_ids,source_mask=source_mask,training_steps=1,training_epochs=training_epochs,learning_rate=0.01,seed=seed,noise_scale=noise_scale,timesteps=timesteps,info=self.info,guidance=denoise_guidance)
201
+ z_optim = trainable_noise_list[0]
202
+ self.info = info
203
+ else:
204
+ z_optim = add_masked_noise_to_z(self.z0,source_mask,width,height,seed=seed,noise_scale=noise_scale)
205
+ trainable_noise_list = None
206
+
207
+ # denoise (editing)
208
+ inp_target["img"] = z_optim
209
+ timesteps = get_schedule(inversion_num_steps, inp_target["img"].shape[1], shift=True)
210
+ self.model.eval()
211
+ torch.set_grad_enabled(False)
212
+ x, _, _, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=denoise_guidance, inverse=False, info=self.info, trainable_noise_list = trainable_noise_list)
213
+
214
+ # decode latents to pixel space
215
+ batch_x = unpack(x.float(), width,height)
216
+
217
+ for x in batch_x:
218
+ x = x.unsqueeze(0)
219
+
220
+
221
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
222
+ x = self.ae.decode(x)
223
+
224
+ if torch.cuda.is_available():
225
+ torch.cuda.synchronize()
226
+ # bring into PIL format and save
227
+ x = x.clamp(-1, 1)
228
+ x = embed_watermark(x.float())
229
+ x = rearrange(x[0], "c h w -> h w c")
230
+
231
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
232
+ exif_data = Image.Exif()
233
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
234
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
235
+ if self.save:
236
+ output_path = os.path.join(self.output_dir,f'{target_object}_{injection_num_steps:02d}_{inversion_num_steps}_seed_{seed}_epoch_{training_epochs:03d}_scale_{noise_scale:.2f}.png')
237
+ img.save(output_path, exif=exif_data, quality=95, subsampling=0)
238
+ masked_image.save(output_path.replace(target_object,f'{target_object}_masked'))
239
+ binary_mask = np.where(mask != 0, 255, 0).astype(np.uint8)
240
+ Image.fromarray(binary_mask, mode="L").save(output_path.replace(target_object,f'{target_object}_mask'))
241
+ t1 = time.perf_counter()
242
+ print(f"Done in {t1 - t0:.1f}s.", f'Saving {output_path} .' if self.save else 'No saving files.')
243
+
244
+ return img
245
+
246
+ def encode(self,init_image, torch_device):
247
+ init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
248
+ init_image = init_image.unsqueeze(0)
249
+ init_image = init_image.to(torch_device)
250
+ self.ae.encoder.to(torch_device)
251
+
252
+ init_image = self.ae.encode(init_image).to(torch.bfloat16)
253
+ return init_image
254
+
255
+ def get_v_src_masks(self,mask,width,height,device,txt_length=512):
256
+ # resize mask to token size
257
+ mask = (mask > 127).astype(np.uint8)
258
+ mask = mask * 255
259
+ pil_mask = Image.fromarray(mask)
260
+ pil_mask = pil_mask.resize((math.ceil(height/16), math.ceil(width/16)), Image.Resampling.LANCZOS)
261
+
262
+ mask = transform(pil_mask)
263
+ mask = mask.flatten().to(device)
264
+
265
+ s_mask = mask.view(1, 1, -1, 1)
266
+ s_mask = s_mask.to(torch.bfloat16)
267
+ v_mask = torch.cat([torch.ones(txt_length).to(device),mask])
268
+ v_mask = v_mask.view(1, 1, -1, 1)
269
+ v_mask = v_mask.to(torch.bfloat16)
270
+ return v_mask,s_mask
271
+
272
+ def create_demo(model_name: str):
273
+ editor = FluxEditor_lore_demo(model_name)
274
+ is_schnell = model_name == "flux-schnell"
275
+
276
+ title = r"""
277
+ <h1 align="center">🎨 LORE Image Editing </h1>
278
+ """
279
+
280
+ description = r"""
281
+ <b>Official 🤗 Gradio demo</b> <br>
282
+ <b>LORE: Latent Optimization for Precise Semantic Control in Rectified Flow-based Image Editing.</b><br>
283
+ <b>Here are the editing steps:</b> <br>
284
+ 1️⃣ Upload your source image. <br>
285
+ 2️⃣ Fill in your source prompt and click the "Inverse" button to perform image inversion. <br>
286
+ 3️⃣ Use the brush tool to draw your mask (on layer 1). <br>
287
+ 4️⃣ Fill in your target prompt, then adjust the hyperparameters. <br>
288
+ 5️⃣ Click the "Edit" button to generate your edited image! <br>
289
+ 6️⃣ If the source image and source prompt are unchanged, you can click "Edit" again for the next generation without re-running inversion. <br>
290
+
291
+ 🔔 [<b>Note</b>] Due to limited resources, images are resized so that the longer side is at most 800 pixels. <br>
292
+ """
293
+ article = r"""
294
+ https://github.com/oyly16/LORE
295
+ """
296
+
297
+ with gr.Blocks() as demo:
298
+ gr.HTML(title)
299
+ gr.Markdown(description)
300
+
301
+ with gr.Row():
302
+ with gr.Column():
303
+ src_prompt = gr.Textbox(label="Source Prompt", value='' )
304
+ inversion_num_steps = gr.Slider(1, 50, 15, step=1, label="Number of inversion/denoise steps")
305
+ injection_num_steps = gr.Slider(1, 50, 12, step=1, label="Number of masked value injection steps")
306
+ target_prompt = gr.Textbox(label="Target Prompt", value='' )
307
+ target_object = gr.Textbox(label="Target Object", value='' )
308
+ target_object_index = gr.Textbox(label="Target Object Index (start index from 0 in target prompt)", value='' )
309
+ brush_canvas = gr.ImageEditor(label="Brush Canvas",
310
+ sources=('upload'),
311
+ brush=gr.Brush(colors=["#ff0000"],color_mode='fixed',default_color="#ff0000"),
312
+ interactive=True,
313
+ transforms=[],
314
+ container=True,
315
+ format='png',scale=1)
316
+
317
+ inv_btn = gr.Button("Inverse")
318
+ edit_btn = gr.Button("Edit")
319
+
320
+
321
+ with gr.Column():
322
+ with gr.Accordion("Advanced Options", open=True):
323
+
324
+ training_epochs = gr.Slider(0, 30, 10, step=1, label="Number of LORE training epochs")
325
+ inversion_guidance = gr.Slider(1.0, 10.0, 1.0, step=0.1, label="inversion Guidance", interactive=not is_schnell)
326
+ denoise_guidance = gr.Slider(1.0, 10.0, 2.0, step=0.1, label="denoise Guidance", interactive=not is_schnell)
327
+ noise_scale = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="renoise scale")
328
+ seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)
329
+
330
+
331
+ output_image = gr.Image(label="Generated Image")
332
+ gr.Markdown(article)
333
+ inv_btn.click(
334
+ fn=editor.inverse,
335
+ inputs=[brush_canvas,src_prompt,
336
+ inversion_num_steps, injection_num_steps,
337
+ inversion_guidance,
338
+ ],
339
+ outputs=[output_image]
340
+ )
341
+ edit_btn.click(
342
+ fn=editor.edit,
343
+ inputs=[brush_canvas,src_prompt,inversion_guidance,
344
+ target_prompt, target_object,target_object_index,
345
+ inversion_num_steps, injection_num_steps,
346
+ training_epochs,
347
+ denoise_guidance,noise_scale,seed,
348
+ ],
349
+ outputs=[output_image]
350
+ )
351
+ gr.Examples(
352
+ examples=[
353
+ ["examples/woman.png", "a young woman", 15, 12, "a young woman with a necklace", "necklace", "5", 10, 0.9, "3"],
354
+ ["examples/car.png", "a taxi in a neon-lit street", 30, 24, "a race car in a neon-lit street", "race car", "1", 5, 0.1, "2388791121"],
355
+ ["examples/cup.png", "a cup on a wooden table", 10, 8, "a wooden table", "table", "2", 2, 0, "0"],
356
+ ],
357
+ inputs=[
358
+ brush_canvas,
359
+ src_prompt,
360
+ inversion_num_steps,
361
+ injection_num_steps,
362
+ target_prompt,
363
+ target_object,
364
+ target_object_index,
365
+ training_epochs,
366
+ noise_scale,
367
+ seed,
368
+ ],
369
+ label="Examples (Click to load)"
370
+ )
371
+
372
+ return demo
373
+
374
+ demo = create_demo("flux-dev")
375
+ demo.launch()
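
For reference, a minimal sketch of driving the editor class above without the Gradio UI. It mirrors the "cup" entry from gr.Examples; the way the brush-canvas dict is assembled here is an assumption for illustration (the real app receives it from gr.ImageEditor, with the mask in the alpha channel of the first layer), and it presumes the example mask has the same size as the source image.

    import numpy as np
    from PIL import Image

    editor = FluxEditor_lore_demo("flux-dev")

    # background: RGBA source image; layers[0]: RGBA layer whose alpha channel is the brush mask
    background = np.array(Image.open("examples/cup.png").convert("RGBA"))
    mask = np.array(Image.open("examples/cup_mask.png").convert("L"))
    layer = np.zeros_like(background)
    layer[:, :, 3] = mask
    canvas = {"background": background, "layers": [layer]}

    # inverse(brush_canvas, src_prompt, inversion_num_steps, injection_num_steps, inversion_guidance)
    editor.inverse(canvas, "a cup on a wooden table", 10, 8, 1.0)

    # edit(brush_canvas, source_prompt, inversion_guidance, target_prompt, target_object,
    #      target_object_index, inversion_num_steps, injection_num_steps, training_epochs,
    #      denoise_guidance, noise_scale, seed)
    img = editor.edit(canvas, "a cup on a wooden table", 1.0,
                      "a wooden table", "table", "2",
                      10, 8, 2, 2.0, 0.0, "0")
    img.save("edited.png")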
examples/car.png ADDED

Git LFS Details

  • SHA256: 040816a3187067fcb9f3a234ed1c1efe635ff3c2159523efb3b6d5af1c5f35aa
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
examples/car_mask.png ADDED

Git LFS Details

  • SHA256: 98ced88759e4c132e9c805dcf12b4989f8343a2fee638ec27f346f4ec5b73b9c
  • Pointer size: 129 Bytes
  • Size of remote file: 2.86 kB
examples/cup.png ADDED

Git LFS Details

  • SHA256: 3c14dddd2b1a09e7107a439116fd415d571e75bb74157e8ab9e9fd4653aea8ea
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
examples/cup_mask.png ADDED

Git LFS Details

  • SHA256: 55d6ab29c8632d8f293eeeedc01c98fe91a06c36e1f412bf16d0dc9762972c27
  • Pointer size: 129 Bytes
  • Size of remote file: 1.59 kB
examples/woman.png ADDED

Git LFS Details

  • SHA256: 38e91557b08ca8e78d22d3523c8bfa0a01082dd62c832a2c87d803e8cbfe5ef4
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
examples/woman_mask.png ADDED

Git LFS Details

  • SHA256: 16975f43f5a9b770b5cea2d97a82210fbafde7b85e0dfc6f374f4af40d6938c4
  • Pointer size: 129 Bytes
  • Size of remote file: 1.26 kB
flux/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ try:
2
+ from ._version import version as __version__ # type: ignore
3
+ from ._version import version_tuple
4
+ except ImportError:
5
+ __version__ = "unknown (no version information available)"
6
+ version_tuple = (0, 0, "unknown", "noinfo")
7
+
8
+ from pathlib import Path
9
+
10
+ PACKAGE = __package__.replace("_", "-")
11
+ PACKAGE_ROOT = Path(__file__).parent
flux/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .cli import app
2
+
3
+ if __name__ == "__main__":
4
+ app()
flux/_version.py ADDED
@@ -0,0 +1,21 @@
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
+ __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
5
+
6
+ TYPE_CHECKING = False
7
+ if TYPE_CHECKING:
8
+ from typing import Tuple
9
+ from typing import Union
10
+
11
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
12
+ else:
13
+ VERSION_TUPLE = object
14
+
15
+ version: str
16
+ __version__: str
17
+ __version_tuple__: VERSION_TUPLE
18
+ version_tuple: VERSION_TUPLE
19
+
20
+ __version__ = version = '0.0.post61+g0274301.d20250318'
21
+ __version_tuple__ = version_tuple = (0, 0, 'g0274301.d20250318')
flux/api.py ADDED
@@ -0,0 +1,194 @@
1
+ import io
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+ API_ENDPOINT = "https://api.bfl.ml"
10
+
11
+
12
+ class ApiException(Exception):
13
+ def __init__(self, status_code: int, detail: str | list[dict] | None = None):
14
+ super().__init__()
15
+ self.detail = detail
16
+ self.status_code = status_code
17
+
18
+ def __str__(self) -> str:
19
+ return self.__repr__()
20
+
21
+ def __repr__(self) -> str:
22
+ if self.detail is None:
23
+ message = None
24
+ elif isinstance(self.detail, str):
25
+ message = self.detail
26
+ else:
27
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
28
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
29
+
30
+
31
+ class ImageRequest:
32
+ def __init__(
33
+ self,
34
+ prompt: str,
35
+ width: int = 1024,
36
+ height: int = 1024,
37
+ name: str = "flux.1-pro",
38
+ num_steps: int = 50,
39
+ prompt_upsampling: bool = False,
40
+ seed: int | None = None,
41
+ validate: bool = True,
42
+ launch: bool = True,
43
+ api_key: str | None = None,
44
+ ):
45
+ """
46
+ Manages an image generation request to the API.
47
+
48
+ Args:
49
+ prompt: Prompt to sample
50
+ width: Width of the image in pixel
51
+ height: Height of the image in pixel
52
+ name: Name of the model
53
+ num_steps: Number of network evaluations
54
+ prompt_upsampling: Use prompt upsampling
55
+ seed: Fix the generation seed
56
+ validate: Run input validation
57
+ launch: Directly launches request
58
+ api_key: Your API key if not provided by the environment
59
+
60
+ Raises:
61
+ ValueError: For invalid input
62
+ ApiException: For errors raised from the API
63
+ """
64
+ if validate:
65
+ if name not in ["flux.1-pro"]:
66
+ raise ValueError(f"Invalid model {name}")
67
+ elif width % 32 != 0:
68
+ raise ValueError(f"width must be divisible by 32, got {width}")
69
+ elif not (256 <= width <= 1440):
70
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
71
+ elif height % 32 != 0:
72
+ raise ValueError(f"height must be divisible by 32, got {height}")
73
+ elif not (256 <= height <= 1440):
74
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
75
+ elif not (1 <= num_steps <= 50):
76
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
77
+
78
+ self.request_json = {
79
+ "prompt": prompt,
80
+ "width": width,
81
+ "height": height,
82
+ "variant": name,
83
+ "steps": num_steps,
84
+ "prompt_upsampling": prompt_upsampling,
85
+ }
86
+ if seed is not None:
87
+ self.request_json["seed"] = seed
88
+
89
+ self.request_id: str | None = None
90
+ self.result: dict | None = None
91
+ self._image_bytes: bytes | None = None
92
+ self._url: str | None = None
93
+ if api_key is None:
94
+ self.api_key = os.environ.get("BFL_API_KEY")
95
+ else:
96
+ self.api_key = api_key
97
+
98
+ if launch:
99
+ self.request()
100
+
101
+ def request(self):
102
+ """
103
+ Request to generate the image.
104
+ """
105
+ if self.request_id is not None:
106
+ return
107
+ response = requests.post(
108
+ f"{API_ENDPOINT}/v1/image",
109
+ headers={
110
+ "accept": "application/json",
111
+ "x-key": self.api_key,
112
+ "Content-Type": "application/json",
113
+ },
114
+ json=self.request_json,
115
+ )
116
+ result = response.json()
117
+ if response.status_code != 200:
118
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
119
+ self.request_id = response.json()["id"]
120
+
121
+ def retrieve(self) -> dict:
122
+ """
123
+ Wait for the generation to finish and retrieve response.
124
+ """
125
+ if self.request_id is None:
126
+ self.request()
127
+ while self.result is None:
128
+ response = requests.get(
129
+ f"{API_ENDPOINT}/v1/get_result",
130
+ headers={
131
+ "accept": "application/json",
132
+ "x-key": self.api_key,
133
+ },
134
+ params={
135
+ "id": self.request_id,
136
+ },
137
+ )
138
+ result = response.json()
139
+ if "status" not in result:
140
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
141
+ elif result["status"] == "Ready":
142
+ self.result = result["result"]
143
+ elif result["status"] == "Pending":
144
+ time.sleep(0.5)
145
+ else:
146
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
147
+ return self.result
148
+
149
+ @property
150
+ def bytes(self) -> bytes:
151
+ """
152
+ Generated image as bytes.
153
+ """
154
+ if self._image_bytes is None:
155
+ response = requests.get(self.url)
156
+ if response.status_code == 200:
157
+ self._image_bytes = response.content
158
+ else:
159
+ raise ApiException(status_code=response.status_code)
160
+ return self._image_bytes
161
+
162
+ @property
163
+ def url(self) -> str:
164
+ """
165
+ Public url to retrieve the image from
166
+ """
167
+ if self._url is None:
168
+ result = self.retrieve()
169
+ self._url = result["sample"]
170
+ return self._url
171
+
172
+ @property
173
+ def image(self) -> Image.Image:
174
+ """
175
+ Load the image as a PIL Image
176
+ """
177
+ return Image.open(io.BytesIO(self.bytes))
178
+
179
+ def save(self, path: str):
180
+ """
181
+ Save the generated image to a local path
182
+ """
183
+ suffix = Path(self.url).suffix
184
+ if not path.endswith(suffix):
185
+ path = path + suffix
186
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
187
+ with open(path, "wb") as file:
188
+ file.write(self.bytes)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ from fire import Fire
193
+
194
+ Fire(ImageRequest)
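
For reference, a minimal usage sketch for the ImageRequest helper above. It assumes a valid BFL_API_KEY is exported in the environment; the prompt and output path are placeholders.

    from flux.api import ImageRequest

    request = ImageRequest(prompt="a forest at dawn", width=1024, height=768)  # launches the request immediately
    print(request.url)         # public URL of the generated sample
    request.save("forest")     # the URL's file suffix is appended automatically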
flux/math.py ADDED
@@ -0,0 +1,57 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ q, k = apply_rope(q, k, pe)
8
+
9
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
+ x = rearrange(x, "B H L D -> B L (H D)")
11
+
12
+ return x
13
+
14
+ def attention_with_attnmap(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
15
+ q, k = apply_rope(q, k, pe)
16
+
17
+ x= torch.nn.functional.scaled_dot_product_attention(q, k, v)
18
+ x = rearrange(x, "B H L D -> B L (H D)")
19
+
20
+ # get attn map
21
+ d_k = q.shape[-1] # head_dim (D)
22
+ attn_map = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5) # [B, H, L, L]
23
+ return x, attn_map
24
+
25
+ def attention_with_attnmap_injection(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attnmap_idxs, old_attnmaps) -> Tensor:
26
+ q, k = apply_rope(q, k, pe)
27
+
28
+ # original attn
29
+ # x= torch.nn.functional.scaled_dot_product_attention(q, k, v)
30
+ # x = rearrange(x, "B H L D -> B L (H D)")
31
+
32
+ # get attn map
33
+ d_k = q.shape[-1] # head_dim (D)
34
+ attn_map = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5) # [B, H, L, L]
35
+ attn_map = torch.softmax(attn_map, dim=-1)
36
+ # inject attn map
37
+ for idx,old_attnmap in zip(attnmap_idxs,old_attnmaps):
38
+ attn_map[:,:,512:,idx] = old_attnmap
39
+ x = attn_map @ v
40
+ return x, attn_map
41
+
42
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
43
+ assert dim % 2 == 0
44
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
45
+ omega = 1.0 / (theta**scale)
46
+ out = torch.einsum("...n,d->...nd", pos, omega)
47
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
48
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
49
+ return out.float()
50
+
51
+
52
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
53
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
54
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
55
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
56
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
57
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
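
For reference, the rotary embedding that rope() and apply_rope() implement, written out (notation is informal; the code above is authoritative). For a position p and channel pair i with d channels per head,

    \omega_i = \theta^{-2i/d}, \qquad
    R(p, i) = \begin{pmatrix} \cos(p\,\omega_i) & -\sin(p\,\omega_i) \\ \sin(p\,\omega_i) & \cos(p\,\omega_i) \end{pmatrix},

and apply_rope rotates every (even, odd) channel pair of q and k,

    \begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix} = R(p, i) \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix},

so the dot products used in attention() depend on token positions only through their relative offset.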
flux/model_lore.py ADDED
@@ -0,0 +1,124 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flux.modules.layers_lore import (DoubleStreamBlock, EmbedND, LastLayer,
7
+ MLPEmbedder, SingleStreamBlock,
8
+ timestep_embedding)
9
+
10
+
11
+ @dataclass
12
+ class FluxParams:
13
+ in_channels: int
14
+ vec_in_dim: int
15
+ context_in_dim: int
16
+ hidden_size: int
17
+ mlp_ratio: float
18
+ num_heads: int
19
+ depth: int
20
+ depth_single_blocks: int
21
+ axes_dim: list[int]
22
+ theta: int
23
+ qkv_bias: bool
24
+ guidance_embed: bool
25
+
26
+
27
+ class Flux(nn.Module):
28
+ """
29
+ Transformer model for flow matching on sequences.
30
+ """
31
+
32
+ def __init__(self, params: FluxParams):
33
+ super().__init__()
34
+
35
+ self.params = params
36
+ self.in_channels = params.in_channels
37
+ self.out_channels = self.in_channels
38
+ if params.hidden_size % params.num_heads != 0:
39
+ raise ValueError(
40
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
41
+ )
42
+ pe_dim = params.hidden_size // params.num_heads
43
+ if sum(params.axes_dim) != pe_dim:
44
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
45
+ self.hidden_size = params.hidden_size
46
+ self.num_heads = params.num_heads
47
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
48
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
49
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
50
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
51
+ self.guidance_in = (
52
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
53
+ )
54
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
55
+
56
+ self.double_blocks = nn.ModuleList(
57
+ [
58
+ DoubleStreamBlock(
59
+ self.hidden_size,
60
+ self.num_heads,
61
+ mlp_ratio=params.mlp_ratio,
62
+ qkv_bias=params.qkv_bias,
63
+ )
64
+ for _ in range(params.depth)
65
+ ]
66
+ )
67
+
68
+ self.single_blocks = nn.ModuleList(
69
+ [
70
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
71
+ for _ in range(params.depth_single_blocks)
72
+ ]
73
+ )
74
+
75
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
76
+
77
+ def forward(
78
+ self,
79
+ img: Tensor,
80
+ img_ids: Tensor,
81
+ txt: Tensor,
82
+ txt_ids: Tensor,
83
+ timesteps: Tensor,
84
+ y: Tensor,
85
+ guidance: Tensor | None = None,
86
+ info = None,
87
+ ) -> Tensor:
88
+ if img.ndim != 3 or txt.ndim != 3:
89
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
90
+
91
+ # running on sequences img
92
+ img = self.img_in(img)
93
+ vec = self.time_in(timestep_embedding(timesteps, 256))
94
+ if self.params.guidance_embed:
95
+ if guidance is None:
96
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
97
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
98
+ vec = vec + self.vector_in(y)
99
+ txt = self.txt_in(txt)
100
+
101
+
102
+ ids = torch.cat((txt_ids, img_ids), dim=1)
103
+ pe = self.pe_embedder(ids)
104
+
105
+ attn_maps = []
106
+
107
+ for block in self.double_blocks:
108
+ img, txt, attn_map = block(img=img, txt=txt, vec=vec, pe=pe, info=info)
109
+ attn_maps.append(attn_map)
110
+
111
+
112
+ cnt = 0
113
+ img = torch.cat((txt, img), 1)
114
+ info['type'] = 'single'
115
+ for block in self.single_blocks:
116
+ info['id'] = cnt
117
+ img, info, attn_map = block(img, vec=vec, pe=pe, info=info)
118
+ attn_maps.append(attn_map)
119
+ cnt += 1
120
+ attn_maps = torch.stack(attn_maps)
121
+ img = img[:, txt.shape[1] :, ...]
122
+
123
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) # 1, N, 64
124
+ return img, info, attn_maps
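
For reference, the conditioning vector assembled at the top of Flux.forward above can be written informally as

    \mathrm{vec} = \mathrm{MLP}_t\big(\phi_{256}(t)\big) + \mathrm{MLP}_g\big(\phi_{256}(g)\big) + \mathrm{MLP}_y(y),

where \phi_{256} is the sinusoidal timestep_embedding, t the current timestep, g the distilled-guidance strength (added only when params.guidance_embed is set), and y the pooled text vector. vec modulates every double and single block; each block also returns a text-to-image attention map, and the maps are stacked into attn_maps alongside the final prediction.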
flux/modules/autoencoder.py ADDED
@@ -0,0 +1,313 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute in_ch_mult, block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # z to block_in
239
+ h = self.conv_in(z)
240
+
241
+ # middle
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.attn_1(h)
244
+ h = self.mid.block_2(h)
245
+
246
+ # upsampling
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ for i_block in range(self.num_res_blocks + 1):
249
+ h = self.up[i_level].block[i_block](h)
250
+ if len(self.up[i_level].attn) > 0:
251
+ h = self.up[i_level].attn[i_block](h)
252
+ if i_level != 0:
253
+ h = self.up[i_level].upsample(h)
254
+
255
+ # end
256
+ h = self.norm_out(h)
257
+ h = swish(h)
258
+ h = self.conv_out(h)
259
+ return h
260
+
261
+
262
+ class DiagonalGaussian(nn.Module):
263
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
+ super().__init__()
265
+ self.sample = sample
266
+ self.chunk_dim = chunk_dim
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
+ # import pdb;pdb.set_trace()
271
+ if self.sample:
272
+ std = torch.exp(0.5 * logvar)
273
+ return mean #+ std * torch.randn_like(mean)
274
+ else:
275
+ return mean
276
+
277
+
278
+ class AutoEncoder(nn.Module):
279
+ def __init__(self, params: AutoEncoderParams):
280
+ super().__init__()
281
+ self.encoder = Encoder(
282
+ resolution=params.resolution,
283
+ in_channels=params.in_channels,
284
+ ch=params.ch,
285
+ ch_mult=params.ch_mult,
286
+ num_res_blocks=params.num_res_blocks,
287
+ z_channels=params.z_channels,
288
+ )
289
+ self.decoder = Decoder(
290
+ resolution=params.resolution,
291
+ in_channels=params.in_channels,
292
+ ch=params.ch,
293
+ out_ch=params.out_ch,
294
+ ch_mult=params.ch_mult,
295
+ num_res_blocks=params.num_res_blocks,
296
+ z_channels=params.z_channels,
297
+ )
298
+ self.reg = DiagonalGaussian()
299
+
300
+ self.scale_factor = params.scale_factor
301
+ self.shift_factor = params.shift_factor
302
+
303
+ def encode(self, x: Tensor) -> Tensor:
304
+ z = self.reg(self.encoder(x))
305
+ z = self.scale_factor * (z - self.shift_factor)
306
+ return z
307
+
308
+ def decode(self, z: Tensor) -> Tensor:
309
+ z = z / self.scale_factor + self.shift_factor
310
+ return self.decoder(z)
311
+
312
+ def forward(self, x: Tensor) -> Tensor:
313
+ return self.decode(self.encode(x))
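
For reference, with s = scale_factor and \delta = shift_factor, encode/decode above amount to

    z = s\,\big(\mu(E(x)) - \delta\big), \qquad \hat{x} = D\!\left(\tfrac{z}{s} + \delta\right),

where \mu(\cdot) is the mean half of the encoder output. Note that DiagonalGaussian here returns only the mean (the sampling term is commented out), so encoding is deterministic, which suits the inversion step performed in app.py.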
flux/modules/conditioner_lore.py ADDED
@@ -0,0 +1,155 @@
1
+ from torch import Tensor, nn
2
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
3
+ T5Tokenizer)
4
+
5
+
6
+ class HFEmbedder(nn.Module):
7
+ def __init__(self, version: str, max_length: int, is_clip, **hf_kwargs):
8
+ super().__init__()
9
+ self.is_clip = is_clip
10
+ self.max_length = max_length
11
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
12
+
13
+ if self.is_clip:
14
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
15
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
16
+ else:
17
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
18
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
19
+
20
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
21
+
22
+ def forward(self, text: list[str]) -> Tensor:
23
+ batch_encoding = self.tokenizer(
24
+ text,
25
+ truncation=True,
26
+ max_length=self.max_length,
27
+ return_length=False,
28
+ return_overflowing_tokens=False,
29
+ padding="max_length",
30
+ return_tensors="pt",
31
+ )
32
+ if not self.is_clip:
33
+ pass
34
+
35
+ outputs = self.hf_module(
36
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
37
+ attention_mask=None,
38
+ output_hidden_states=False,
39
+ )
40
+ return outputs[self.output_key]
41
+
42
+ def forward_length(self, text: list[str]) -> Tensor:
43
+ batch_encoding = self.tokenizer(
44
+ text,
45
+ truncation=True,
46
+ max_length=self.max_length,
47
+ return_length=True,
48
+ return_overflowing_tokens=False,
49
+ padding="max_length",
50
+ return_tensors="pt",
51
+ )
52
+ if not self.is_clip:
53
+ pass
54
+
55
+ outputs = self.hf_module(
56
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
57
+ attention_mask=None,
58
+ output_hidden_states=False,
59
+ )
60
+ # -1 to delete the end token
61
+ return outputs[self.output_key],batch_encoding['length']-1
62
+
63
+ def get_word_embed(self, text: list[str]) -> Tensor:
64
+ batch_encoding = self.tokenizer(
65
+ text,
66
+ truncation=True,
67
+ max_length=16,
68
+ return_length=True,
69
+ return_overflowing_tokens=False,
70
+ padding="max_length",
71
+ return_tensors="pt",
72
+ )
73
+
74
+ input_ids = batch_encoding["input_ids"].to(self.hf_module.device)
75
+ attention_mask = batch_encoding["attention_mask"].to(self.hf_module.device)
76
+
77
+ outputs = self.hf_module(
78
+ input_ids=input_ids,
79
+ attention_mask=attention_mask,
80
+ output_hidden_states=False,
81
+ )
82
+
83
+ token_embeddings = outputs[self.output_key] # [B, T, D]
84
+ mask = attention_mask.unsqueeze(-1).float() # [B, T, 1]
85
+ summed = (token_embeddings * mask).sum(dim=1) # [B, D]
86
+ counts = mask.sum(dim=1).clamp(min=1e-6)
87
+ mean_pooled = summed / counts # [B, D]
88
+
89
+ return mean_pooled
90
+
91
+
92
+ def get_text_embeddings_with_diff(self, src_text: str, tgt_text: str, replacements: list[tuple[str, str, int, int]], show_tokens=False, return_embeds=False):
93
+ batch_encoding = self.tokenizer(
94
+ [src_text, tgt_text],
95
+ truncation=True,
96
+ max_length=self.max_length,
97
+ return_tensors="pt",
98
+ padding="max_length",
99
+ )
100
+
101
+ src_ids, tgt_ids = batch_encoding["input_ids"]
102
+
103
+ src_tokens = self.tokenizer.tokenize(src_text)
104
+ tgt_tokens = self.tokenizer.tokenize(tgt_text)
105
+ if show_tokens:
106
+ print("src tokens", src_tokens)
107
+ print("tgt tokens", tgt_tokens)
108
+
109
+ src_dif_ids = []
110
+ tgt_dif_ids = []
111
+ def find_mappings(tokens,words,start_idx):
112
+ if (words is None) or start_idx<0: # some samples do not need this
113
+ return [-1]
114
+ res = []
115
+ flag = 0
116
+ for i in range(start_idx,len(tokens)):
117
+ this_token = tokens[i].strip('▁')
118
+ if this_token == "":
119
+ continue
120
+ if words.startswith(this_token):
121
+ res.append(i)
122
+ flag = 1
123
+ if words.endswith(this_token):
124
+ break
125
+ else:
126
+ continue
127
+ if flag and words.endswith(this_token):
128
+ res.append(i)
129
+ break
130
+ if flag:
131
+ res.append(i)
132
+ return res
133
+
134
+ for src_words, tgt_words, src_index, tgt_index in replacements:
135
+ if src_words:
136
+ src_dif_ids.append(find_mappings(src_tokens,src_words,src_index))
137
+ else:
138
+ src_dif_ids.append([-1])
139
+ if tgt_words:
140
+ tgt_dif_ids.append(find_mappings(tgt_tokens,tgt_words,tgt_index))
141
+ else:
142
+ tgt_dif_ids.append([-1])
143
+
144
+ if return_embeds:
145
+ outputs = self.hf_module(
146
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
147
+ attention_mask=None,
148
+ output_hidden_states=False,
149
+ )
150
+ embeddings = outputs[self.output_key]
151
+ else:
152
+ embeddings = (None,None)
153
+ return embeddings[0], embeddings[1], src_dif_ids, tgt_dif_ids
154
+
155
+
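
For reference, a minimal sketch of calling get_text_embeddings_with_diff directly, mirroring how app.py builds its replacements list (each entry is [src_words, tgt_words, src_start_index, tgt_start_index], with None / -1 meaning "unused on that side"). The embedder construction below is an assumption for illustration; in the app the HFEmbedder comes from load_t5 in flux.util_lore.

    t5 = load_t5("cuda", max_length=512)  # HFEmbedder wrapping the T5 text encoder

    # locate the tokens of "necklace", starting the search at index 5 of the target prompt's token list
    replacements = [[None, "necklace", -1, 5]]
    _, _, src_ids, tgt_ids = t5.get_text_embeddings_with_diff(
        "a young woman",
        "a young woman with a necklace",
        replacements,
        show_tokens=True,      # print the T5 tokenizations of both prompts
        return_embeds=False,   # only the differing token indices are needed here
    )
    # tgt_ids[0] now holds the token positions of "necklace" in the target prompt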
flux/modules/layers_lore.py ADDED
@@ -0,0 +1,298 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from flux.math import attention, rope, attention_with_attnmap
9
+
10
+ import os
11
+
12
+ class EmbedND(nn.Module):
13
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.theta = theta
17
+ self.axes_dim = axes_dim
18
+
19
+ def forward(self, ids: Tensor) -> Tensor:
20
+ n_axes = ids.shape[-1]
21
+ emb = torch.cat(
22
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
23
+ dim=-3,
24
+ )
25
+
26
+ return emb.unsqueeze(1)
27
+
28
+
29
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
30
+ """
31
+ Create sinusoidal timestep embeddings.
32
+ :param t: a 1-D Tensor of N indices, one per batch element.
33
+ These may be fractional.
34
+ :param dim: the dimension of the output.
35
+ :param max_period: controls the minimum frequency of the embeddings.
36
+ :return: an (N, D) Tensor of positional embeddings.
37
+ """
38
+ t = time_factor * t
39
+ half = dim // 2
40
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
41
+ t.device
42
+ )
43
+
44
+ args = t[:, None].float() * freqs[None]
45
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
46
+ if dim % 2:
47
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
48
+ if torch.is_floating_point(t):
49
+ embedding = embedding.to(t)
50
+ return embedding
51
+
52
+
53
+ class MLPEmbedder(nn.Module):
54
+ def __init__(self, in_dim: int, hidden_dim: int):
55
+ super().__init__()
56
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
57
+ self.silu = nn.SiLU()
58
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
59
+
60
+ def forward(self, x: Tensor) -> Tensor:
61
+ return self.out_layer(self.silu(self.in_layer(x)))
62
+
63
+
64
+ class RMSNorm(torch.nn.Module):
65
+ def __init__(self, dim: int):
66
+ super().__init__()
67
+ self.scale = nn.Parameter(torch.ones(dim))
68
+
69
+ def forward(self, x: Tensor):
70
+ x_dtype = x.dtype
71
+ x = x.float()
72
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
73
+ return (x * rrms).to(dtype=x_dtype) * self.scale
74
+
75
+
76
+ class QKNorm(torch.nn.Module):
77
+ def __init__(self, dim: int):
78
+ super().__init__()
79
+ self.query_norm = RMSNorm(dim)
80
+ self.key_norm = RMSNorm(dim)
81
+
82
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
83
+ q = self.query_norm(q)
84
+ k = self.key_norm(k)
85
+ return q.to(v), k.to(v)
86
+
87
+
88
+ class SelfAttention(nn.Module):
89
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
90
+ super().__init__()
91
+ self.num_heads = num_heads
92
+ head_dim = dim // num_heads
93
+
94
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
95
+ self.norm = QKNorm(head_dim)
96
+ self.proj = nn.Linear(dim, dim)
97
+
98
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
99
+ qkv = self.qkv(x)
100
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
101
+ q, k = self.norm(q, k, v)
102
+ x = attention(q, k, v, pe=pe)
103
+ x = self.proj(x)
104
+ return x
105
+
106
+
107
+ @dataclass
108
+ class ModulationOut:
109
+ shift: Tensor
110
+ scale: Tensor
111
+ gate: Tensor
112
+
113
+
114
+ class Modulation(nn.Module):
115
+ def __init__(self, dim: int, double: bool):
116
+ super().__init__()
117
+ self.is_double = double
118
+ self.multiplier = 6 if double else 3
119
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
120
+
121
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
122
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
123
+
124
+ return (
125
+ ModulationOut(*out[:3]),
126
+ ModulationOut(*out[3:]) if self.is_double else None,
127
+ )
128
+
129
+
130
+ class DoubleStreamBlock(nn.Module):
131
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
132
+ super().__init__()
133
+
134
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
135
+ self.num_heads = num_heads
136
+ self.hidden_size = hidden_size
137
+ self.img_mod = Modulation(hidden_size, double=True)
138
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
139
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
140
+
141
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142
+ self.img_mlp = nn.Sequential(
143
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
144
+ nn.GELU(approximate="tanh"),
145
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
146
+ )
147
+
148
+ self.txt_mod = Modulation(hidden_size, double=True)
149
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
150
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
151
+
152
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153
+ self.txt_mlp = nn.Sequential(
154
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
155
+ nn.GELU(approximate="tanh"),
156
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
157
+ )
158
+
159
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, Tensor]:
160
+ img_mod1, img_mod2 = self.img_mod(vec)
161
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
162
+
163
+ # prepare image for attention
164
+ img_modulated = self.img_norm1(img)
165
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
166
+ img_qkv = self.img_attn.qkv(img_modulated)
167
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
168
+
169
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
170
+
171
+ # prepare txt for attention
172
+ txt_modulated = self.txt_norm1(txt)
173
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
174
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
175
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
176
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
177
+
178
+ # run actual attention
179
+ q = torch.cat((txt_q, img_q), dim=2) #[8, 24, 512, 128] + [8, 24, 900, 128] -> [8, 24, 1412, 128]
180
+ k = torch.cat((txt_k, img_k), dim=2)
181
+ v = torch.cat((txt_v, img_v), dim=2)
182
+ attn,attn_map = attention_with_attnmap(q, k, v, pe=pe)
183
+ attn_map = attn_map[:, :, txt.shape[1]:, :txt.shape[1]] # text to image attn map
184
+ if 'txt_token_l' in info:
185
+ # drop all paddings
186
+ attn_map = attn_map[:,:,:,:info['txt_token_l']]
187
+ attn_map = torch.nn.functional.softmax(attn_map, dim=-1) # softmax
188
+ attn_map = attn_map.mean(dim=1) # average over all 24 heads
189
+
190
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
191
+
192
+ # calculate the img blocks
193
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
194
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
195
+
196
+ # calculate the txt blocks
197
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
198
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
199
+ return img, txt, attn_map
200
+
201
+
202
+ class SingleStreamBlock(nn.Module):
203
+ """
204
+ A DiT block with parallel linear layers as described in
205
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
206
+ """
207
+
208
+ def __init__(
209
+ self,
210
+ hidden_size: int,
211
+ num_heads: int,
212
+ mlp_ratio: float = 4.0,
213
+ qk_scale: float | None = None,
214
+ ):
215
+ super().__init__()
216
+ self.hidden_dim = hidden_size
217
+ self.num_heads = num_heads
218
+ head_dim = hidden_size // num_heads
219
+ self.scale = qk_scale or head_dim**-0.5
220
+
221
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
222
+ # qkv and mlp_in
223
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
224
+ # proj and mlp_out
225
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
226
+
227
+ self.norm = QKNorm(head_dim)
228
+
229
+ self.hidden_size = hidden_size
230
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
231
+
232
+ self.mlp_act = nn.GELU(approximate="tanh")
233
+ self.modulation = Modulation(hidden_size, double=False)
234
+
235
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, info) -> Tensor:
236
+ mod, _ = self.modulation(vec)
237
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
238
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
239
+
240
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
241
+ q, k = self.norm(q, k, v)
242
+
243
+ # Note: if your device does not have enough memory, consider uncommenting the following code to cache the V features on disk instead of keeping them in memory.
244
+ # if info['inject'] and info['id'] > 19:
245
+ # store_path = os.path.join(info['feature_path'], str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V' + '.pth')
246
+ # if info['inverse']:
247
+ # torch.save(v, store_path)
248
+ # if not info['inverse']:
249
+ # v = torch.load(store_path, weights_only=True)
250
+
251
+ # Save the V features in memory (original threshold: block id > 19)
252
+ if info['inject'] and info['id'] > 19:
253
+ if 'ref' not in info:
254
+ info['ref'] = False
255
+ feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + str(info['ref']) + '_' + 'V'
256
+ if info['inverse']:
257
+ info['feature'][feature_name] = v.cpu()
258
+ else:
259
+ # v injection with mask
260
+ # 0: original RF-Edit
261
+ # 1: new_v_text + old_v_image
262
+ # 2: new_v*mask + old_v*(1-mask)
263
+ if info['change_v'] == 0:
264
+ v = info['feature'][feature_name].cuda()
265
+ elif info['change_v'] == 1:
266
+ old_v = info['feature'][feature_name].cuda()
267
+ v = torch.cat([v[:, :, :512, :], old_v[:, :, 512:, :]], dim=2)
268
+ elif info['change_v'] == 2:
269
+ old_v = info['feature'][feature_name].cuda()
270
+ v = v * info['v_mask'] + old_v * (1 - info['v_mask'])
271
+
272
+
273
+
274
+ # compute attention
275
+ attn, attn_map = attention_with_attnmap(q, k, v, pe=pe)
276
+ attn_map = attn_map[:, :, 512:, :512] # attention logits of image queries over text keys (512 = text length)
277
+ if 'txt_token_l' in info:
278
+ # drop padding tokens beyond the true text length
279
+ attn_map = attn_map[:,:,:,:info['txt_token_l']]
280
+ attn_map = torch.nn.functional.softmax(attn_map, dim=-1) # softmax
281
+ attn_map = attn_map.mean(dim=1) # average over all 24 attention heads
282
+ # compute activation in mlp stream, cat again and run second linear layer
283
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
284
+ return x + mod.gate * output, info, attn_map
285
+
286
+
287
+ class LastLayer(nn.Module):
288
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
289
+ super().__init__()
290
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
291
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
292
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
293
+
294
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
295
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
296
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
297
+ x = self.linear(x)
298
+ return x
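
Both stream blocks above call attention_with_attnmap, which is defined elsewhere in this repo and not shown in this commit view. A minimal sketch of the assumed contract, with RoPE application elided and the function name below chosen for illustration only: standard scaled dot-product attention over the fused txt+img sequence, returning both the merged-head output and the raw pre-softmax scores that each block later slices, softmaxes, and averages over heads.

import torch
from einops import rearrange
from torch import Tensor

def attention_with_attnmap_sketch(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
    # q, k, v: [B, H, L, D] over the concatenated txt+img tokens; pe: rotary embedding
    # (the real implementation applies pe to q and k before the dot product)
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhld,bhmd->bhlm", q, k) * scale  # raw logits [B, H, L, L]
    out = torch.softmax(scores, dim=-1) @ v                 # [B, H, L, D]
    out = rearrange(out, "b h l d -> b l (h d)")            # merge heads, as the blocks expect
    return out, scores                                      # scores are softmaxed per block afterwards
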
flux/sampling_lore.py ADDED
@@ -0,0 +1,372 @@
1
+ import math
2
+ import random
3
+ import copy
4
+ from typing import Callable
5
+
6
+ import torch
7
+ import numpy as np
8
+ from einops import rearrange, repeat
9
+ from torch import Tensor
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ import torch.nn.functional as F
13
+ from PIL import Image
14
+ from torchvision import transforms
15
+
16
+ from .model_lore import Flux
17
+ from .modules.conditioner_lore import HFEmbedder
18
+
19
+ def prepare_tokens(t5, source_prompt, target_prompt, replacements,show_tokens=False):
20
+ _, _, src_dif_ids, tgt_dif_ids=t5.get_text_embeddings_with_diff(source_prompt,target_prompt,replacements,show_tokens=show_tokens)
21
+ return src_dif_ids,tgt_dif_ids
22
+
23
+ transform = transforms.ToTensor()
24
+
25
+ def get_mask_one_tensor(mask_dirs,width,height,device):
26
+ res = []
27
+ for mask_dir in mask_dirs:
28
+ mask_image = Image.open(mask_dir).convert('L')
29
+ # resize
30
+ mask_image = mask_image.resize((math.ceil(height/16), math.ceil(width/16)), Image.Resampling.LANCZOS)
31
+ mask_tensor = transform(mask_image)
32
+ mask_tensor = mask_tensor.squeeze(0)
33
+ # to one dim
34
+ mask_tensor = mask_tensor.flatten()
35
+ mask_tensor = mask_tensor.to(device)
36
+ res.append(mask_tensor)
37
+ res = sum(res)
38
+ res = res.view(1, 1, -1, 1)
39
+ res = res.to(torch.bfloat16)
40
+ return res
41
+
42
+ def get_v_mask(mask_dirs,width,height,device,txt_length=512):
43
+ res = []
44
+ for mask_dir in mask_dirs:
45
+ mask_image = Image.open(mask_dir).convert('L')
46
+ # resize
47
+ mask_image = mask_image.resize((math.ceil(height/16), math.ceil(width/16)), Image.Resampling.LANCZOS)
48
+ mask_tensor = transform(mask_image)
49
+ mask_tensor = mask_tensor.squeeze(0)
50
+ # to one dim
51
+ mask_tensor = mask_tensor.flatten()
52
+ mask_tensor = mask_tensor.to(device)
53
+ res.append(mask_tensor)
54
+ res = sum(res)
55
+ res = torch.cat([torch.ones(txt_length).to(device),res])
56
+ res = res.view(1, 1, -1, 1)
57
+ res = res.to(torch.bfloat16)
58
+ return res
59
+
60
+ def add_masked_noise_to_z(z,mask,width,height,seed=42,noise_scale=0.1):
61
+ if noise_scale == 0:
62
+ return z
63
+ noise = torch.randn(z.shape,device=z.device,dtype=z.dtype,generator=torch.Generator(device=z.device).manual_seed(seed))
64
+ if noise_scale > 10:
65
+ return noise
66
+ # blend noise into the masked region only: z is kept outside the mask and mixed with scaled noise inside it
67
+ z = z*(1-mask[0])+noise_scale*noise*mask[0]+(1-noise_scale)*z*mask[0]
68
+ return z
69
+
70
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
71
+ bs, c, h, w = img.shape
72
+ if bs == 1 and not isinstance(prompt, str):
73
+ bs = len(prompt)
74
+
75
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
76
+ if img.shape[0] == 1 and bs > 1:
77
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
78
+
79
+ img_ids = torch.zeros(h // 2, w // 2, 3)
80
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
81
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
82
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
83
+
84
+ if isinstance(prompt, str):
85
+ prompt = [prompt]
86
+ txt = t5(prompt)
87
+ if txt.shape[0] == 1 and bs > 1:
88
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
89
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
90
+
91
+ vec = clip(prompt)
92
+ if vec.shape[0] == 1 and bs > 1:
93
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
94
+
95
+ return {
96
+ "img": img,
97
+ "img_ids": img_ids.to(img.device),
98
+ "txt": txt.to(img.device),
99
+ "txt_ids": txt_ids.to(img.device),
100
+ "vec": vec.to(img.device),
101
+ }
102
+
103
+
104
+ def time_shift(mu: float, sigma: float, t: Tensor):
105
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
106
+
107
+
108
+ def get_lin_function(
109
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
110
+ ) -> Callable[[float], float]:
111
+ m = (y2 - y1) / (x2 - x1)
112
+ b = y1 - m * x1
113
+ return lambda x: m * x + b
114
+
115
+
116
+ def get_schedule(
117
+ num_steps: int,
118
+ image_seq_len: int,
119
+ base_shift: float = 0.5,
120
+ max_shift: float = 1.15,
121
+ shift: bool = True,
122
+ ) -> list[float]:
123
+ # extra step for zero
124
+ timesteps = torch.linspace(1, 0, num_steps + 1)
125
+
126
+ # shifting the schedule to favor high timesteps for higher signal images
127
+ if shift:
128
+ # estimate mu based on linear estimation between two points
129
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
130
+ timesteps = time_shift(mu, 1.0, timesteps)
131
+
132
+ return timesteps.tolist()
133
+
134
+ def denoise(
135
+ model: Flux,
136
+ # model input
137
+ img: Tensor,
138
+ img_ids: Tensor,
139
+ txt: Tensor,
140
+ txt_ids: Tensor,
141
+ vec: Tensor,
142
+ # sampling parameters
143
+ timesteps: list[float],
144
+ inverse,
145
+ info,
146
+ guidance: float = 4.0,
147
+ trainable_noise_list=None,
148
+ ):
149
+ # this is ignored for schnell
150
+ inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
151
+
152
+
153
+ if inverse:
154
+ timesteps = timesteps[::-1]
155
+ inject_list = inject_list[::-1]
156
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
157
+
158
+ step_list = []
159
+ attn_map_list = []
160
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
161
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
162
+ info['t'] = t_prev if inverse else t_curr
163
+ info['inverse'] = inverse
164
+ info['second_order'] = False
165
+ info['inject'] = inject_list[i]
166
+ # when editing add optim latent for several steps
167
+ if trainable_noise_list and i != 0 and i<len(trainable_noise_list):
168
+ # smask = info['source_mask'].squeeze(0)
169
+ # img = trainable_noise_list[i]*smask+img*(1-smask)
170
+ img = trainable_noise_list[i]
171
+
172
+ pred, info, attn_maps_mid = model(
173
+ img=img,
174
+ img_ids=img_ids,
175
+ txt=txt,
176
+ txt_ids=txt_ids,
177
+ y=vec,
178
+ timesteps=t_vec,
179
+ guidance=guidance_vec,
180
+ info=info
181
+ )
182
+
183
+ img_mid = img + (t_prev - t_curr) / 2 * pred
184
+
185
+ t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
186
+ info['second_order'] = True
187
+ pred_mid, info, attn_maps = model(
188
+ img=img_mid,
189
+ img_ids=img_ids,
190
+ txt=txt,
191
+ txt_ids=txt_ids,
192
+ y=vec,
193
+ timesteps=t_vec_mid,
194
+ guidance=guidance_vec,
195
+ info=info
196
+ )
197
+
198
+ first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
199
+ img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order
200
+
201
+ # collect attention maps (shape [L, 1, N_img, N_txt]) for each step
202
+ step_list.append(t_curr)
203
+ attn_map_list.append((attn_maps_mid+attn_maps)/2)
204
+
205
+ attn_map_list = torch.stack(attn_map_list)
206
+ return img, info, step_list, attn_map_list
207
+
208
+ selected_layers = range(8,44)
209
+
210
+ def gaussian_smooth(attnmap,wh,kernel_size=3,sigma=0.5):
211
+ # to 2d
212
+ attnmap = rearrange(
213
+ attnmap,
214
+ "b (w h) -> b (w) (h)",
215
+ w=math.ceil(wh[0]/16),
216
+ h=math.ceil(wh[1]/16),
217
+ )
218
+ attnmap = attnmap.unsqueeze(1)
219
+ # prepare kernel
220
+ ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1., device=attnmap.device)
221
+ xx, yy = torch.meshgrid(ax, ax, indexing='ij')
222
+ kernel = torch.exp(-(xx**2 + yy**2) / (2. * sigma**2))
223
+ kernel = kernel / kernel.sum()
224
+ kernel = kernel.view(1, 1, kernel_size, kernel_size)
225
+ kernel = kernel.to(dtype=attnmap.dtype)
226
+ # gaussian smooth
227
+ attnmap_smoothed = F.conv2d(attnmap, kernel, padding=kernel_size // 2)
228
+ return attnmap_smoothed.view(attnmap_smoothed.shape[0], -1)
229
+
230
+ def compute_attn_max_loss(attnmaps,source_mask,wh):
231
+ # attnmaps L,1,N,k
232
+ attnmaps = attnmaps[selected_layers,0,:,:]
233
+ attnmaps = attnmaps.mean(dim=-1)
234
+ src_mask = source_mask.view(-1).unsqueeze(0)
235
+ p = attnmaps*src_mask
236
+ p = gaussian_smooth(p, wh, kernel_size=3, sigma=0.5)
237
+ p = p.max(dim=1).values
238
+ loss = (1 - p).mean()
239
+ return loss
240
+
241
+ def compute_attn_min_loss(attnmaps,source_mask,wh):
242
+ # attnmaps L,1,N,k
243
+ attnmaps = attnmaps[selected_layers,0,:,:]
244
+ attnmaps = attnmaps.mean(dim=-1)
245
+ src_mask = source_mask.view(-1).unsqueeze(0)
246
+ p = attnmaps*src_mask
247
+ p = gaussian_smooth(p, wh, kernel_size=3, sigma=0.5)
248
+ p = p.max(dim=1).values
249
+ loss = p.mean()
250
+ return loss
251
+
252
+ def denoise_with_noise_optim(
253
+ model: Flux,
254
+ # model input
255
+ img: Tensor,
256
+ img_ids: Tensor,
257
+ txt: Tensor,
258
+ txt_ids: Tensor,
259
+ vec: Tensor,
260
+ # loss cal
261
+ token_ids: list[list[int]],
262
+ source_mask: Tensor,
263
+ training_steps: int,
264
+ training_epochs: int,
265
+ learning_rate: float,
266
+ seed: int,
267
+ noise_scale: float,
268
+ # sampling parameters
269
+ timesteps: list[float],
270
+ info,
271
+ guidance: float = 4.0
272
+ ):
273
+ # this is ignored for schnell
274
+ #print(f'training the noise in last {training_steps} steps and {training_epochs} epochs')
275
+ #timesteps = timesteps[::-1]
276
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
277
+
278
+ step_list = []
279
+ attn_map_list = []
280
+ trainable_noise_list = []
281
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
282
+ if i >= training_steps:
283
+ break
284
+ # prepare ori parameters
285
+ ori_txt = txt.clone()
286
+ ori_img = img.clone()
287
+ ori_vec = vec.clone()
288
+
289
+ # prepare trainable noise
290
+ if i == 0:
291
+ if noise_scale == 0:
292
+ trainable_noise = torch.nn.Parameter(img.clone().detach(), requires_grad=True)
293
+ else:
294
+ noise = torch.randn(img.shape,device=img.device,dtype=img.dtype,generator=torch.Generator(device=img.device).manual_seed(seed))
295
+ noise = img*(1-source_mask[0])+ noise_scale*noise*source_mask[0] + (1-noise_scale)*img*source_mask[0]
296
+ trainable_noise = torch.nn.Parameter(noise.clone().detach(), requires_grad=True)
297
+ else:
298
+ trainable_noise = torch.nn.Parameter(img.clone().detach(), requires_grad=True)
299
+ optimizer = optim.Adam([trainable_noise], lr=learning_rate)
300
+
301
+ # run the optimization epochs for this timestep
302
+ for j in range(training_epochs):
303
+ optimizer.zero_grad()
304
+ txt = ori_txt.clone().detach()
305
+ vec = ori_vec.clone().detach()
306
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
307
+ info['t'] = t_prev
308
+ info['inverse'] = False
309
+ info['second_order'] = False
310
+ info['inject'] = False # tried True, seems not necessary
311
+ pred, info, attn_maps_mid = model(
312
+ img=trainable_noise,
313
+ img_ids=img_ids,
314
+ txt=txt,
315
+ txt_ids=txt_ids,
316
+ y=vec,
317
+ timesteps=t_vec,
318
+ guidance=guidance_vec,
319
+ info=info
320
+ )
321
+
322
+
323
+ img_mid = trainable_noise + (t_prev - t_curr) / 2 * pred
324
+
325
+ t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
326
+ info['second_order'] = True
327
+ pred_mid, info, attn_maps = model(
328
+ img=img_mid,
329
+ img_ids=img_ids,
330
+ txt=txt,
331
+ txt_ids=txt_ids,
332
+ y=vec,
333
+ timesteps=t_vec_mid,
334
+ guidance=guidance_vec,
335
+ info=info
336
+ )
337
+
338
+ first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
339
+ img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order
340
+
341
+ # attn maps of shape [L, 1, N_img, N_txt], used to compute the loss
342
+ attn_maps=(attn_maps_mid+attn_maps)/2
343
+ total_loss = 0.0
344
+ for indices,change,ratio in token_ids:
345
+ if change:
346
+ total_loss += compute_attn_max_loss(attn_maps[:,:,:,indices], source_mask, info['wh'])
347
+ else:
348
+ if ratio != 0:
349
+ total_loss += ratio*compute_attn_min_loss(attn_maps[:,:,:,indices], source_mask, info['wh'])
350
+ total_loss.backward()
351
+ with torch.no_grad():
352
+ trainable_noise.grad *= source_mask[0]
353
+ optimizer.step()
354
+ print(f"Time {t_curr:.4f} Epoch {j+1}/{training_epochs}, Loss: {total_loss.item():.6f}")
355
+
356
+ attn_map_list.append(attn_maps.detach())
357
+ step_list.append(t_curr)
358
+ trainable_noise = trainable_noise.detach()
359
+ trainable_noise_list.append(trainable_noise.clone())
360
+
361
+ attn_map_list = torch.stack(attn_map_list)
362
+ return img, info, step_list, attn_map_list, trainable_noise_list
363
+
364
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
365
+ return rearrange(
366
+ x,
367
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
368
+ h=math.ceil(height / 16),
369
+ w=math.ceil(width / 16),
370
+ ph=2,
371
+ pw=2,
372
+ )
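
The denoise loop above advances the latent with two model evaluations per step. A standalone sketch of that second-order update, assuming only that the model behaves as a velocity field; the helper below is illustrative and not part of the repo:

import torch

def midpoint_step(x: torch.Tensor, t_curr: float, t_prev: float, velocity):
    dt = t_prev - t_curr
    v1 = velocity(x, t_curr)               # first evaluation at the current time
    x_mid = x + dt / 2 * v1                # half step
    v2 = velocity(x_mid, t_curr + dt / 2)  # second evaluation at the midpoint
    # denoise() writes the update as x + dt*v1 + 0.5*dt**2 * (v2 - v1) / (dt/2),
    # which simplifies to the classic midpoint rule:
    return x + dt * v2

x = torch.ones(4)
for t_curr, t_prev in zip([1.0, 0.5], [0.5, 0.0]):
    x = midpoint_step(x, t_curr, t_prev, lambda z, t: -z)  # toy ODE dz/dt = -z
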
flux/util_lore.py ADDED
@@ -0,0 +1,208 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from imwatermark import WatermarkEncoder
8
+ from safetensors.torch import load_file as load_sft
9
+
10
+ from flux.model_lore import Flux, FluxParams
11
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
12
+ from flux.modules.conditioner_lore import HFEmbedder
13
+
14
+
15
+ @dataclass
16
+ class ModelSpec:
17
+ params: FluxParams
18
+ ae_params: AutoEncoderParams
19
+ ckpt_path: str | None
20
+ ae_path: str | None
21
+ repo_id: str | None
22
+ repo_flow: str | None
23
+ repo_ae: str | None
24
+
25
+ # download model from hf
26
+ flux_path = "black-forest-labs/FLUX.1-dev"
27
+ flux_ckpt_path = os.getenv("FLUX_DEV")
28
+ flux_ae_path = os.getenv("AE")
29
+ t5_path = "google/t5-v1_1-xxl"
30
+ clip_path = "openai/clip-vit-large-patch14"
31
+
32
+ configs = {
33
+ "flux-dev": ModelSpec(
34
+ repo_id=flux_path,
35
+ repo_flow="flux1-dev.safetensors",
36
+ repo_ae="ae.safetensors",
37
+ ckpt_path=flux_ckpt_path,
38
+ params=FluxParams(
39
+ in_channels=64,
40
+ vec_in_dim=768,
41
+ context_in_dim=4096,
42
+ hidden_size=3072,
43
+ mlp_ratio=4.0,
44
+ num_heads=24,
45
+ depth=19,
46
+ depth_single_blocks=38,
47
+ axes_dim=[16, 56, 56],
48
+ theta=10_000,
49
+ qkv_bias=True,
50
+ guidance_embed=True,
51
+ ),
52
+ ae_path=flux_ae_path,
53
+ ae_params=AutoEncoderParams(
54
+ resolution=256,
55
+ in_channels=3,
56
+ ch=128,
57
+ out_ch=3,
58
+ ch_mult=[1, 2, 4, 4],
59
+ num_res_blocks=2,
60
+ z_channels=16,
61
+ scale_factor=0.3611,
62
+ shift_factor=0.1159,
63
+ ),
64
+ ),
65
+ "flux-schnell": ModelSpec(
66
+ repo_id="black-forest-labs/FLUX.1-schnell",
67
+ repo_flow="flux1-schnell.safetensors",
68
+ repo_ae="ae.safetensors",
69
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
70
+ params=FluxParams(
71
+ in_channels=64,
72
+ vec_in_dim=768,
73
+ context_in_dim=4096,
74
+ hidden_size=3072,
75
+ mlp_ratio=4.0,
76
+ num_heads=24,
77
+ depth=19,
78
+ depth_single_blocks=38,
79
+ axes_dim=[16, 56, 56],
80
+ theta=10_000,
81
+ qkv_bias=True,
82
+ guidance_embed=False,
83
+ ),
84
+ ae_path=os.getenv("AE"),
85
+ ae_params=AutoEncoderParams(
86
+ resolution=256,
87
+ in_channels=3,
88
+ ch=128,
89
+ out_ch=3,
90
+ ch_mult=[1, 2, 4, 4],
91
+ num_res_blocks=2,
92
+ z_channels=16,
93
+ scale_factor=0.3611,
94
+ shift_factor=0.1159,
95
+ ),
96
+ ),
97
+ }
98
+
99
+
100
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
101
+ if len(missing) > 0 and len(unexpected) > 0:
102
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
103
+ print("\n" + "-" * 79 + "\n")
104
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
105
+ elif len(missing) > 0:
106
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
107
+ elif len(unexpected) > 0:
108
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
109
+
110
+
111
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
112
+ # Loading Flux
113
+ print("Init model")
114
+
115
+ ckpt_path = configs[name].ckpt_path
116
+ if (
117
+ ckpt_path is None
118
+ and configs[name].repo_id is not None
119
+ and configs[name].repo_flow is not None
120
+ and hf_download
121
+ ):
122
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
123
+
124
+ with torch.device("meta" if ckpt_path is not None else device):
125
+ model = Flux(configs[name].params).to(torch.bfloat16)
126
+
127
+ if ckpt_path is not None:
128
+ print("Loading checkpoint on", device, ckpt_path)
129
+ # load_sft doesn't support torch.device
130
+ sd = load_sft(ckpt_path, device=str(device))
131
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
132
+ print_load_warning(missing, unexpected)
133
+ return model
134
+
135
+
136
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
137
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
138
+ return HFEmbedder(t5_path, max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
139
+
140
+
141
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
142
+ return HFEmbedder(clip_path, max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
143
+
144
+
145
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
146
+ ckpt_path = configs[name].ae_path
147
+ if (
148
+ ckpt_path is None
149
+ and configs[name].repo_id is not None
150
+ and configs[name].repo_ae is not None
151
+ and hf_download
152
+ ):
153
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
154
+
155
+ # Loading the autoencoder
156
+ print("Init AE")
157
+ with torch.device("meta" if ckpt_path is not None else device):
158
+ ae = AutoEncoder(configs[name].ae_params)
159
+
160
+ if ckpt_path is not None:
161
+ sd = load_sft(ckpt_path, device=str(device))
162
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
163
+ print_load_warning(missing, unexpected)
164
+ return ae
165
+
166
+
167
+ class WatermarkEmbedder:
168
+ def __init__(self, watermark):
169
+ self.watermark = watermark
170
+ self.num_bits = len(WATERMARK_BITS)
171
+ self.encoder = WatermarkEncoder()
172
+ self.encoder.set_watermark("bits", self.watermark)
173
+
174
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
175
+ """
176
+ Adds a predefined watermark to the input image
177
+
178
+ Args:
179
+ image: ([N,] B, RGB, H, W) in range [-1, 1]
180
+
181
+ Returns:
182
+ same as input but watermarked
183
+ """
184
+ image = 0.5 * image + 0.5
185
+ squeeze = len(image.shape) == 4
186
+ if squeeze:
187
+ image = image[None, ...]
188
+ n = image.shape[0]
189
+ image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
190
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
191
+ # the watermarking library expects input in cv2 BGR format
192
+ for k in range(image_np.shape[0]):
193
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
194
+ image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
195
+ image.device
196
+ )
197
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
198
+ if squeeze:
199
+ image = image[0]
200
+ image = 2 * image - 1
201
+ return image
202
+
203
+
204
+ # A fixed 48-bit message that was chosen at random
205
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
206
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
207
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
208
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
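
embed_watermark is the module's public entry point; a minimal usage sketch, with a random tensor standing in for a decoded image batch:

import torch
from flux.util_lore import embed_watermark

x = torch.rand(1, 3, 256, 256) * 2 - 1  # stand-in for decoded images in [-1, 1]
x = embed_watermark(x)                  # same shape and range, watermark embedded
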
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ pydantic==2.10.6
2
+ torch
4
+ accelerate==0.34.2
5
+ einops==0.8.0
6
+ transformers==4.41.2
7
+ huggingface-hub==0.24.6
8
+ datasets
9
+ omegaconf
10
+ diffusers
11
+ sentencepiece
12
+ opencv-python
13
+ matplotlib
14
+ onnxruntime
15
+ torchvision
16
+ timm
17
+ invisible-watermark
18
+ fire
19
+ tqdm
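
These dependencies can be installed with pip install -r requirements.txt before running the demo.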