Spaces: Running on Zero
oyly committed · Commit 87fa4fd · 1 Parent(s): a8d4753
first commit
- .gitattributes +1 -0
- .gitignore +3 -0
- app.py +375 -0
- examples/car.png +3 -0
- examples/car_mask.png +3 -0
- examples/cup.png +3 -0
- examples/cup_mask.png +3 -0
- examples/woman.png +3 -0
- examples/woman_mask.png +3 -0
- flux/__init__.py +11 -0
- flux/__main__.py +4 -0
- flux/_version.py +21 -0
- flux/api.py +194 -0
- flux/math.py +57 -0
- flux/model_lore.py +124 -0
- flux/modules/autoencoder.py +313 -0
- flux/modules/conditioner_lore.py +155 -0
- flux/modules/layers_lore.py +298 -0
- flux/sampling_lore.py +372 -0
- flux/util_lore.py +208 -0
- requirements.txt +19 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
.ipynb_checkpoints/
__pycache__/
*.pyc
app.py
ADDED
@@ -0,0 +1,375 @@
import os
import re
import time
import math
from dataclasses import dataclass
from glob import iglob
import argparse
from einops import rearrange
from PIL import ExifTags, Image
import torch
import gradio as gr
import numpy as np
import spaces
from huggingface_hub import login
login(token=os.getenv('Token'))
from flux.sampling_lore import denoise, get_schedule, prepare, unpack, get_v_mask, add_masked_noise_to_z, get_mask_one_tensor, denoise_with_noise_optim, prepare_tokens
from flux.util_lore import (configs, embed_watermark, load_ae, load_clip,
                            load_flow_model, load_t5)

def encode(init_image, torch_device, ae):
    init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
    init_image = init_image.unsqueeze(0)
    init_image = init_image.to(torch_device)
    init_image = ae.encode(init_image.to()).to(torch.bfloat16)
    return init_image

from torchvision import transforms
transform = transforms.ToTensor()

class FluxEditor_lore_demo:
    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.offload = False

        self.name = model_name
        self.is_schnell = model_name == "flux-schnell"
        self.resize_longside = 800
        self.save = False

        self.output_dir = 'outputs_gradio'

        self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
        self.clip = load_clip(self.device)
        self.model = load_flow_model(model_name, device=self.device)
        self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)

        self.t5.eval()
        self.clip.eval()
        self.ae.eval()
        self.info = {}
        if self.offload:
            self.model.cpu()
            torch.cuda.empty_cache()
            self.ae.encoder.to(self.device)
        for param in self.model.parameters():
            param.requires_grad = False  # freeze the model
        for param in self.t5.parameters():
            param.requires_grad = False  # freeze the model
        for param in self.clip.parameters():
            param.requires_grad = False  # freeze the model
        for param in self.ae.parameters():
            param.requires_grad = False  # freeze the model

    def resize_image(self, image):
        pil_image = Image.fromarray(image)
        h, w = pil_image.size[1], pil_image.size[0]
        if h <= self.resize_longside and w <= self.resize_longside:
            return image

        if h >= w:
            new_h = self.resize_longside
            new_w = int(w * self.resize_longside / h)
        else:
            new_w = self.resize_longside
            new_h = int(h * self.resize_longside / w)

        resized_image = pil_image.resize((new_w, new_h), Image.LANCZOS)
        return np.array(resized_image)

    def resize_mask(self, mask, height, width):
        pil_mask = Image.fromarray(mask.astype(np.uint8))  # ensure it's 8-bit for PIL
        resized_pil = pil_mask.resize((width, height), Image.NEAREST)  # width first!
        return np.array(resized_pil)

    @spaces.GPU(duration=240)
    def inverse(self, brush_canvas, src_prompt,
                inversion_num_steps, injection_num_steps,
                inversion_guidance,
                ):
        print(f"Inversing {src_prompt}, guidance {inversion_guidance}, inje/step {injection_num_steps}/{inversion_num_steps}")
        self.z0 = None
        self.zt = None
        torch.cuda.empty_cache()
        if self.info:
            del self.info
        self.info = {'src_p': src_prompt}

        rgba_init_image = brush_canvas["background"]
        init_image = rgba_init_image[:, :, :3]

        if self.resize_longside != -1:
            init_image = self.resize_image(init_image)
        shape = init_image.shape

        new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
        new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16

        init_image = init_image[:new_h, :new_w, :]
        width, height = init_image.shape[0], init_image.shape[1]
        self.init_image = encode(init_image, self.device, self.ae)

        if self.save:
            ori_output_path = os.path.join(self.output_dir, f'{src_prompt[:20]}_ori.png')
            Image.fromarray(init_image, 'RGB').save(ori_output_path)

        t0 = time.perf_counter()

        self.info['feature'] = {}
        self.info['inject_step'] = injection_num_steps
        self.info['wh'] = (width, height)

        torch.cuda.empty_cache()

        inp = prepare(self.t5, self.clip, self.init_image, prompt=src_prompt)
        timesteps = get_schedule(inversion_num_steps, inp["img"].shape[1], shift=True)
        self.info['x_ori'] = inp["img"].clone()

        # inversion initial noise
        torch.set_grad_enabled(False)
        z, info, _, _ = denoise(self.model, **inp, timesteps=timesteps, guidance=inversion_guidance, inverse=True, info=self.info)
        self.z0 = z
        self.info = info

        t1 = time.perf_counter()
        print(f"inversion Done in {t1 - t0:.1f}s.")
        return init_image

    @spaces.GPU(duration=240)
    def edit(self, brush_canvas, source_prompt, inversion_guidance,
             target_prompt, target_object, target_object_index,
             inversion_num_steps, injection_num_steps,
             training_epochs,
             denoise_guidance, noise_scale, seed,
             ):

        torch.cuda.empty_cache()
        if 'src_p' not in self.info or self.info['src_p'] != source_prompt:
            print('src prompt changed. inverse again')
            self.inverse(brush_canvas, source_prompt,
                         inversion_num_steps, injection_num_steps,
                         inversion_guidance)

        rgba_init_image = brush_canvas["background"]
        rgba_mask = brush_canvas["layers"][0]
        init_image = rgba_init_image[:, :, :3]
        if self.resize_longside != -1:
            init_image = self.resize_image(init_image)
        width, height = self.info['wh']
        init_image = init_image[:width, :height, :]
        # rgba_init_image = rgba_init_image[:height, :width, :]

        if self.resize_longside != -1:
            mask = self.resize_mask(rgba_mask[:, :, 3], height, width)
        else:
            mask = rgba_mask[:width, :height, 3]
        mask = mask.astype(int)

        rgba_mask[:, :, 3] = rgba_mask[:, :, 3] // 2
        masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'), Image.fromarray(rgba_mask, 'RGBA'))
        masked_image = masked_image.resize((height, width), Image.LANCZOS)

        # prepare source mask and vmask
        inp_optim = prepare(self.t5, self.clip, self.init_image, prompt=target_prompt)
        inp_target = prepare(self.t5, self.clip, self.init_image, prompt=target_prompt)
        v_mask, source_mask = self.get_v_src_masks(mask, width, height, self.device)
        self.info['change_v'] = 2  # v_mask
        self.info['v_mask'] = v_mask
        self.info['source_mask'] = source_mask
        self.info['inject_step'] = injection_num_steps
        timesteps = get_schedule(inversion_num_steps, inp_optim["img"].shape[1], shift=True)
        seed = int(seed)
        if seed == -1:
            seed = torch.randint(0, 2**32, (1,)).item()

        # prepare token_ids
        token_ids = []
        replacements = [[None, target_object, -1, int(target_object_index)]]
        src_dif_ids, tgt_dif_ids = prepare_tokens(self.t5, source_prompt, target_prompt, replacements, True)
        for t_ids in tgt_dif_ids:
            token_ids.append([t_ids, True, 1])
        print('token_ids', token_ids)
        # do latent optim

        t0 = time.perf_counter()
        print(f'optimizing & editing noise, {target_prompt} with seed {seed}, noise_scale {noise_scale}, training_epochs {training_epochs}')
        if training_epochs != 0:
            torch.set_grad_enabled(True)
            inp_optim["img"] = self.z0
            _, info, _, _, trainable_noise_list = denoise_with_noise_optim(self.model, **inp_optim, token_ids=token_ids, source_mask=source_mask, training_steps=1, training_epochs=training_epochs, learning_rate=0.01, seed=seed, noise_scale=noise_scale, timesteps=timesteps, info=self.info, guidance=denoise_guidance)
            z_optim = trainable_noise_list[0]
            self.info = info
        else:
            z_optim = add_masked_noise_to_z(self.z0, source_mask, width, height, seed=seed, noise_scale=noise_scale)
            trainable_noise_list = None

        # denoise (editing)
        inp_target["img"] = z_optim
        timesteps = get_schedule(inversion_num_steps, inp_target["img"].shape[1], shift=True)
        self.model.eval()
        torch.set_grad_enabled(False)
        x, _, _, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=denoise_guidance, inverse=False, info=self.info, trainable_noise_list=trainable_noise_list)

        # decode latents to pixel space
        batch_x = unpack(x.float(), width, height)

        for x in batch_x:
            x = x.unsqueeze(0)

        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
            x = self.ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
        if self.save:
            output_path = os.path.join(self.output_dir, f'{target_object}_{injection_num_steps:02d}_{inversion_num_steps}_seed_{seed}_epoch_{training_epochs:03d}_scale_{noise_scale:.2f}.png')
            img.save(output_path, exif=exif_data, quality=95, subsampling=0)
            masked_image.save(output_path.replace(target_object, f'{target_object}_masked'))
            binary_mask = np.where(mask != 0, 255, 0).astype(np.uint8)
            Image.fromarray(binary_mask, mode="L").save(output_path.replace(target_object, f'{target_object}_mask'))
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s.", f'Saving {output_path} .' if self.save else 'No saving files.')

        return img

    def encode(self, init_image, torch_device):
        init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
        init_image = init_image.unsqueeze(0)
        init_image = init_image.to(torch_device)
        self.ae.encoder.to(torch_device)

        init_image = self.ae.encode(init_image).to(torch.bfloat16)
        return init_image

    def get_v_src_masks(self, mask, width, height, device, txt_length=512):
        # resize mask to token size
        mask = (mask > 127).astype(np.uint8)
        mask = mask * 255
        pil_mask = Image.fromarray(mask)
        pil_mask = pil_mask.resize((math.ceil(height / 16), math.ceil(width / 16)), Image.Resampling.LANCZOS)

        mask = transform(pil_mask)
        mask = mask.flatten().to(device)

        s_mask = mask.view(1, 1, -1, 1)
        s_mask = s_mask.to(torch.bfloat16)
        v_mask = torch.cat([torch.ones(txt_length).to(device), mask])
        v_mask = v_mask.view(1, 1, -1, 1)
        v_mask = v_mask.to(torch.bfloat16)
        return v_mask, s_mask

def create_demo(model_name: str):
    editor = FluxEditor_lore_demo(model_name)
    is_schnell = model_name == "flux-schnell"

    title = r"""
    <h1 align="center">🎨 LORE Image Editing </h1>
    """

    description = r"""
    <b>Official 🤗 Gradio demo</b> <br>
    <b>LORE: Latent Optimization for Precise Semantic Control in Rectified Flow-based Image Editing.</b><br>
    <b>Here are the editing steps:</b> <br>
    1️⃣ Upload your source image. <br>
    2️⃣ Fill in your source prompt and click the "Inverse" button to perform image inversion. <br>
    3️⃣ Use the brush tool to draw your mask (on layer 1). <br>
    4️⃣ Fill in your target prompt, then adjust the hyperparameters. <br>
    5️⃣ Click the "Edit" button to generate your edited image! <br>
    6️⃣ If the source image and prompt are unchanged, you can click "Edit" again for the next generation. <br>

    🔔 [<b>Note</b>] Due to limited resources, images are resized to at most 800 px on the long side. <br>
    """
    article = r"""
    https://github.com/oyly16/LORE
    """

    with gr.Blocks() as demo:
        gr.HTML(title)
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                src_prompt = gr.Textbox(label="Source Prompt", value='')
                inversion_num_steps = gr.Slider(1, 50, 15, step=1, label="Number of inversion/denoise steps")
                injection_num_steps = gr.Slider(1, 50, 12, step=1, label="Number of masked value injection steps")
                target_prompt = gr.Textbox(label="Target Prompt", value='')
                target_object = gr.Textbox(label="Target Object", value='')
                target_object_index = gr.Textbox(label="Target Object Index (start index from 0 in target prompt)", value='')
                brush_canvas = gr.ImageEditor(label="Brush Canvas",
                                              sources=('upload'),
                                              brush=gr.Brush(colors=["#ff0000"], color_mode='fixed', default_color="#ff0000"),
                                              interactive=True,
                                              transforms=[],
                                              container=True,
                                              format='png', scale=1)

                inv_btn = gr.Button("inverse")
                edit_btn = gr.Button("edit")

            with gr.Column():
                with gr.Accordion("Advanced Options", open=True):
                    training_epochs = gr.Slider(0, 30, 10, step=1, label="Number of LORE training epochs")
                    inversion_guidance = gr.Slider(1.0, 10.0, 1.0, step=0.1, label="inversion Guidance", interactive=not is_schnell)
                    denoise_guidance = gr.Slider(1.0, 10.0, 2.0, step=0.1, label="denoise Guidance", interactive=not is_schnell)
                    noise_scale = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="renoise scale")
                    seed = gr.Textbox('0', label="Seed (-1 for random)", visible=True)

                output_image = gr.Image(label="Generated Image")
                gr.Markdown(article)
        inv_btn.click(
            fn=editor.inverse,
            inputs=[brush_canvas, src_prompt,
                    inversion_num_steps, injection_num_steps,
                    inversion_guidance,
                    ],
            outputs=[output_image]
        )
        edit_btn.click(
            fn=editor.edit,
            inputs=[brush_canvas, src_prompt, inversion_guidance,
                    target_prompt, target_object, target_object_index,
                    inversion_num_steps, injection_num_steps,
                    training_epochs,
                    denoise_guidance, noise_scale, seed,
                    ],
            outputs=[output_image]
        )
        gr.Examples(
            examples=[
                ["examples/woman.png", "a young woman", 15, 12, "a young woman with a necklace", "necklace", "5", 10, 0.9, "3"],
                ["examples/car.png", "a taxi in a neon-lit street", 30, 24, "a race car in a neon-lit street", "race car", "1", 5, 0.1, "2388791121"],
                ["examples/cup.png", "a cup on a wooden table", 10, 8, "a wooden table", "table", "2", 2, 0, "0"],
            ],
            inputs=[
                brush_canvas,
                src_prompt,
                inversion_num_steps,
                injection_num_steps,
                target_prompt,
                target_object,
                target_object_index,
                training_epochs,
                noise_scale,
                seed,
            ],
            label="Examples (Click to load)"
        )

    return demo

demo = create_demo("flux-dev")
demo.launch()
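A minimal, model-free sketch of the size handling the inverse/edit path above relies on, using only NumPy and PIL; the array sizes are illustrative, not taken from the commit.

# Illustrative only: crop the upload to multiples of 16 and map the brush mask
# onto the 16x16-patch token grid, mirroring what app.py / get_v_src_masks do.
import math
import numpy as np
from PIL import Image

img = np.zeros((803, 1217, 3), dtype=np.uint8)          # pretend upload
new_h = img.shape[0] - img.shape[0] % 16
new_w = img.shape[1] - img.shape[1] % 16
img = img[:new_h, :new_w]                                # (800, 1216, 3)

mask = np.zeros((new_h, new_w), dtype=np.uint8)          # brush alpha channel
token_grid = (math.ceil(new_w / 16), math.ceil(new_h / 16))   # PIL wants (W, H)
mask_tokens = np.array(Image.fromarray(mask).resize(token_grid, Image.NEAREST))
print(img.shape, mask_tokens.shape)                      # (800, 1216, 3) (50, 76)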
examples/car.png ADDED (Git LFS)
examples/car_mask.png ADDED (Git LFS)
examples/cup.png ADDED (Git LFS)
examples/cup_mask.png ADDED (Git LFS)
examples/woman.png ADDED (Git LFS)
examples/woman_mask.png ADDED (Git LFS)
flux/__init__.py
ADDED
@@ -0,0 +1,11 @@
try:
    from ._version import version as __version__  # type: ignore
    from ._version import version_tuple
except ImportError:
    __version__ = "unknown (no version information available)"
    version_tuple = (0, 0, "unknown", "noinfo")

from pathlib import Path

PACKAGE = __package__.replace("_", "-")
PACKAGE_ROOT = Path(__file__).parent
flux/__main__.py
ADDED
@@ -0,0 +1,4 @@
from .cli import app

if __name__ == "__main__":
    app()
flux/_version.py
ADDED
@@ -0,0 +1,21 @@
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '0.0.post61+g0274301.d20250318'
__version_tuple__ = version_tuple = (0, 0, 'g0274301.d20250318')
flux/api.py
ADDED
@@ -0,0 +1,194 @@
import io
import os
import time
from pathlib import Path

import requests
from PIL import Image

API_ENDPOINT = "https://api.bfl.ml"


class ApiException(Exception):
    def __init__(self, status_code: int, detail: str | list[dict] | None = None):
        super().__init__()
        self.detail = detail
        self.status_code = status_code

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        if self.detail is None:
            message = None
        elif isinstance(self.detail, str):
            message = self.detail
        else:
            message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
        return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"


class ImageRequest:
    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: int | None = None,
        validate: bool = True,
        launch: bool = True,
        api_key: str | None = None,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixel
            height: Height of the image in pixel
            name: Name of the model
            num_steps: Number of network evaluations
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key if not provided by the environment

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        if validate:
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: str | None = None
        self.result: dict | None = None
        self._image_bytes: bytes | None = None
        self._url: str | None = None
        if api_key is None:
            self.api_key = os.environ.get("BFL_API_KEY")
        else:
            self.api_key = api_key

        if launch:
            self.request()

    def request(self):
        """
        Request to generate the image.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
        )
        result = response.json()
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        self.request_id = response.json()["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
            )
            result = response.json()
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            elif result["status"] == "Ready":
                self.result = result["result"]
            elif result["status"] == "Pending":
                time.sleep(0.5)
            else:
                raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes.
        """
        if self._image_bytes is None:
            response = requests.get(self.url)
            if response.status_code == 200:
                self._image_bytes = response.content
            else:
                raise ApiException(status_code=response.status_code)
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: str):
        """
        Save the generated image to a local path
        """
        suffix = Path(self.url).suffix
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes)


if __name__ == "__main__":
    from fire import Fire

    Fire(ImageRequest)
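A hypothetical usage sketch for the API client above; it assumes a valid key in the BFL_API_KEY environment variable and network access, and simply follows the call pattern and validation rules of ImageRequest.

# Illustrative only: issue a request, wait for the result, and save it.
from flux.api import ImageRequest

request = ImageRequest(
    prompt="a cup on a wooden table",
    width=1024,            # must be a multiple of 32, between 256 and 1440
    height=768,
    num_steps=50,
    seed=0,
    launch=True,           # posts the request immediately
)
print(request.url)         # polls /v1/get_result until the image is ready
request.save("outputs/cup")   # the file suffix is appended automatically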
flux/math.py
ADDED
@@ -0,0 +1,57 @@
import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x

def attention_with_attnmap(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    # get attn map
    d_k = q.shape[-1]  # head_dim (D)
    attn_map = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)  # [B, H, L, L]
    return x, attn_map

def attention_with_attnmap_injection(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attnmap_idxs, old_attnmaps) -> Tensor:
    q, k = apply_rope(q, k, pe)

    # original attn
    # x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # x = rearrange(x, "B H L D -> B L (H D)")

    # get attn map
    d_k = q.shape[-1]  # head_dim (D)
    attn_map = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)  # [B, H, L, L]
    attn_map = torch.softmax(attn_map, dim=-1)
    # inject attn map
    for idx, old_attnmap in zip(attnmap_idxs, old_attnmaps):
        attn_map[:, :, 512:, idx] = old_attnmap
    x = attn_map @ v
    return x, attn_map

def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
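A small shape-checking sketch for the RoPE helpers above, assuming flux.math is importable in the current environment; the sizes are toy values chosen only to make the tensors line up.

# Illustrative only: rotary embeddings for a single 1D axis, then attention.
import torch
from flux.math import attention, rope

B, H, L, D = 1, 4, 16, 64                    # batch, heads, sequence length, head dim
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

pos = torch.arange(L, dtype=torch.float64)[None]   # (B, L) positional ids
pe = rope(pos, dim=D, theta=10_000)                # (B, L, D/2, 2, 2)
pe = pe.unsqueeze(1)                               # (B, 1, L, D/2, 2, 2), broadcast over heads

out = attention(q, k, v, pe)                       # (B, L, H*D)
print(out.shape)                                   # torch.Size([1, 16, 256])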
flux/model_lore.py
ADDED
@@ -0,0 +1,124 @@
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from flux.modules.layers_lore import (DoubleStreamBlock, EmbedND, LastLayer,
                                      MLPEmbedder, SingleStreamBlock,
                                      timestep_embedding)


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        info=None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        attn_maps = []

        for block in self.double_blocks:
            img, txt, attn_map = block(img=img, txt=txt, vec=vec, pe=pe, info=info)
            attn_maps.append(attn_map)

        cnt = 0
        img = torch.cat((txt, img), 1)
        info['type'] = 'single'
        for block in self.single_blocks:
            info['id'] = cnt
            img, info, attn_map = block(img, vec=vec, pe=pe, info=info)
            attn_maps.append(attn_map)
            cnt += 1
        attn_maps = torch.stack(attn_maps)
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)  # 1, N, 64
        return img, info, attn_maps
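For orientation, a sketch instantiating FluxParams with the commonly published FLUX.1-dev hyperparameters; the authoritative values for this Space live in the configs dict of flux/util_lore.py (added by this commit but not shown in this view), so treat the numbers below as assumptions rather than the demo's source of truth.

# Assumed reference values for FLUX.1-dev; the demo itself builds the model via
# load_flow_model("flux-dev", ...) using flux/util_lore.py's configs.
from flux.model_lore import FluxParams

params = FluxParams(
    in_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=True,
)
# Flux(params) would allocate the full transformer, so only the dataclass is built here.
print(params.hidden_size // params.num_heads)  # 128, the per-head / positional-embedding dim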
flux/modules/autoencoder.py
ADDED
@@ -0,0 +1,313 @@
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        # import pdb;pdb.set_trace()
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean  # + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
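A toy round-trip sketch of the autoencoder above with deliberately small, assumed parameters and random weights, just to show the tensor shapes; the demo itself loads the pretrained FLUX autoencoder through load_ae in flux/util_lore.py.

# Illustrative only: tiny AutoEncoderParams chosen so the shapes are easy to follow.
import torch
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=64, in_channels=3, ch=32, out_ch=3,
    ch_mult=[1, 2], num_res_blocks=1, z_channels=4,
    scale_factor=1.0, shift_factor=0.0,
)
ae = AutoEncoder(params).eval()

x = torch.randn(1, 3, 64, 64)        # image tensor in [-1, 1]
with torch.no_grad():
    z = ae.encode(x)                  # latent, (1, 4, 32, 32)
    x_rec = ae.decode(z)              # reconstruction, (1, 3, 64, 64)
print(z.shape, x_rec.shape)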
flux/modules/conditioner_lore.py
ADDED
@@ -0,0 +1,155 @@
from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
                          T5Tokenizer)


class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, is_clip, **hf_kwargs):
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]

    def forward_length(self, text: list[str]) -> tuple[Tensor, Tensor]:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        # -1 to drop the end-of-sequence token
        return outputs[self.output_key], batch_encoding['length'] - 1

    def get_word_embed(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=16,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = batch_encoding["input_ids"].to(self.hf_module.device)
        attention_mask = batch_encoding["attention_mask"].to(self.hf_module.device)

        outputs = self.hf_module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
        )

        token_embeddings = outputs[self.output_key]    # [B, T, D]
        mask = attention_mask.unsqueeze(-1).float()    # [B, T, 1]
        summed = (token_embeddings * mask).sum(dim=1)  # [B, D]
        counts = mask.sum(dim=1).clamp(min=1e-6)
        mean_pooled = summed / counts                  # [B, D]

        return mean_pooled

    def get_text_embeddings_with_diff(self, src_text: str, tgt_text: str, replacements: list[tuple[str, str, int, int]], show_tokens=False, return_embeds=False):
        batch_encoding = self.tokenizer(
            [src_text, tgt_text],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            padding="max_length",
        )

        src_ids, tgt_ids = batch_encoding["input_ids"]

        src_tokens = self.tokenizer.tokenize(src_text)
        tgt_tokens = self.tokenizer.tokenize(tgt_text)
        if show_tokens:
            print("src tokens", src_tokens)
            print("tgt tokens", tgt_tokens)

        src_dif_ids = []
        tgt_dif_ids = []

        def find_mappings(tokens, words, start_idx):
            if (words is None) or start_idx < 0:  # some samples do not need this
                return [-1]
            res = []
            flag = 0
            for i in range(start_idx, len(tokens)):
                this_token = tokens[i].strip('▁')
                if this_token == "":
                    continue
                if words.startswith(this_token):
                    res.append(i)
                    flag = 1
                    if words.endswith(this_token):
                        break
                    else:
                        continue
                if flag and words.endswith(this_token):
                    res.append(i)
                    break
                if flag:
                    res.append(i)
            return res

        for src_words, tgt_words, src_index, tgt_index in replacements:
            if src_words:
                src_dif_ids.append(find_mappings(src_tokens, src_words, src_index))
            else:
                src_dif_ids.append([-1])
            if tgt_words:
                tgt_dif_ids.append(find_mappings(tgt_tokens, tgt_words, tgt_index))
            else:
                tgt_dif_ids.append([-1])

        if return_embeds:
            outputs = self.hf_module(
                input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
                attention_mask=None,
                output_hidden_states=False,
            )
            embeddings = outputs[self.output_key]
        else:
            embeddings = (None, None)
        return embeddings[0], embeddings[1], src_dif_ids, tgt_dif_ids
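
For context, a minimal usage sketch of HFEmbedder (not part of the commit); the checkpoint names, max lengths, and dtype simply mirror what load_t5/load_clip in flux/util_lore.py pass in:

import torch
from flux.modules.conditioner_lore import HFEmbedder

t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, is_clip=False, torch_dtype=torch.bfloat16).to("cuda")
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to("cuda")

txt = t5(["a photo of a red car"])    # last_hidden_state, shape [1, 512, 4096]
vec = clip(["a photo of a red car"])  # pooler_output, shape [1, 768]
txt, length = t5.forward_length(["a photo of a red car"])  # length excludes the end token
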
flux/modules/layers_lore.py
ADDED
@@ -0,0 +1,298 @@
import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

from flux.math import attention, rope, attention_with_attnmap

import os


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)

        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)  # [8, 24, 512, 128] + [8, 24, 900, 128] -> [8, 24, 1412, 128]
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)
        attn, attn_map = attention_with_attnmap(q, k, v, pe=pe)
        attn_map = attn_map[:, :, txt.shape[1]:, :txt.shape[1]]  # text-to-image attention map
        if 'txt_token_l' in info:
            # drop all paddings
            attn_map = attn_map[:, :, :, :info['txt_token_l']]
        attn_map = torch.nn.functional.softmax(attn_map, dim=-1)  # softmax over text tokens
        attn_map = attn_map.mean(dim=1)  # average over all heads (24)

        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt, attn_map


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, dict, Tensor]:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # Note: if your device does not have enough memory, consider uncommenting the following code
        # to keep the saved V features on disk instead of in memory.
        # if info['inject'] and info['id'] > 19:
        #     store_path = os.path.join(info['feature_path'], str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V' + '.pth')
        #     if info['inverse']:
        #         torch.save(v, store_path)
        #     if not info['inverse']:
        #         v = torch.load(store_path, weights_only=True)

        # Save the features in memory (only for blocks with id > 19, as in the original setting)
        if info['inject'] and info['id'] > 19:
            if 'ref' not in info:
                info['ref'] = False
            feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + str(info['ref']) + '_' + 'V'
            if info['inverse']:
                info['feature'][feature_name] = v.cpu()
            else:
                # v injection with mask
                # 0: original RF-Edit
                # 1: new_v_text + old_v_image
                # 2: new_v*mask + old_v*(1-mask)
                if info['change_v'] == 0:
                    v = info['feature'][feature_name].cuda()
                elif info['change_v'] == 1:
                    old_v = info['feature'][feature_name].cuda()
                    v = torch.cat([v[:, :, :512, :], old_v[:, :, 512:, :]], dim=2)
                elif info['change_v'] == 2:
                    old_v = info['feature'][feature_name].cuda()
                    v = v * info['v_mask'] + old_v * (1 - info['v_mask'])

        # compute attention
        attn, attn_map = attention_with_attnmap(q, k, v, pe=pe)
        attn_map = attn_map[:, :, 512:, :512]  # text-to-image attention map
        if 'txt_token_l' in info:
            # drop all paddings
            attn_map = attn_map[:, :, :, :info['txt_token_l']]
        attn_map = torch.nn.functional.softmax(attn_map, dim=-1)  # softmax over text tokens
        attn_map = attn_map.mean(dim=1)  # average over all heads (24)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output, info, attn_map


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
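
As a quick illustration of the sinusoidal embedding above (a sketch, not part of the commit):

import torch
from flux.modules.layers_lore import timestep_embedding

t = torch.tensor([0.0, 0.5, 1.0])     # fractional timesteps in [0, 1]
emb = timestep_embedding(t, dim=256)  # internally scaled by time_factor=1000
print(emb.shape)                      # torch.Size([3, 256]): cos and sin halves concatenated
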
flux/sampling_lore.py
ADDED
@@ -0,0 +1,372 @@
import math
import random
import copy
from typing import Callable

import torch
import numpy as np
from einops import rearrange, repeat
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from .model_lore import Flux
from .modules.conditioner_lore import HFEmbedder


def prepare_tokens(t5, source_prompt, target_prompt, replacements, show_tokens=False):
    _, _, src_dif_ids, tgt_dif_ids = t5.get_text_embeddings_with_diff(source_prompt, target_prompt, replacements, show_tokens=show_tokens)
    return src_dif_ids, tgt_dif_ids


transform = transforms.ToTensor()


def get_mask_one_tensor(mask_dirs, width, height, device):
    res = []
    for mask_dir in mask_dirs:
        mask_image = Image.open(mask_dir).convert('L')
        # resize
        mask_image = mask_image.resize((math.ceil(height/16), math.ceil(width/16)), Image.Resampling.LANCZOS)
        mask_tensor = transform(mask_image)
        mask_tensor = mask_tensor.squeeze(0)
        # to one dim
        mask_tensor = mask_tensor.flatten()
        mask_tensor = mask_tensor.to(device)
        res.append(mask_tensor)
    res = sum(res)
    res = res.view(1, 1, -1, 1)
    res = res.to(torch.bfloat16)
    return res


def get_v_mask(mask_dirs, width, height, device, txt_length=512):
    res = []
    for mask_dir in mask_dirs:
        mask_image = Image.open(mask_dir).convert('L')
        # resize
        mask_image = mask_image.resize((math.ceil(height/16), math.ceil(width/16)), Image.Resampling.LANCZOS)
        mask_tensor = transform(mask_image)
        mask_tensor = mask_tensor.squeeze(0)
        # to one dim
        mask_tensor = mask_tensor.flatten()
        mask_tensor = mask_tensor.to(device)
        res.append(mask_tensor)
    res = sum(res)
    res = torch.cat([torch.ones(txt_length).to(device), res])
    res = res.view(1, 1, -1, 1)
    res = res.to(torch.bfloat16)
    return res


def add_masked_noise_to_z(z, mask, width, height, seed=42, noise_scale=0.1):
    if noise_scale == 0:
        return z
    noise = torch.randn(z.shape, device=z.device, dtype=z.dtype, generator=torch.Generator(device=z.device).manual_seed(seed))
    if noise_scale > 10:
        return noise
    # blend noise into the masked region of z, keep the rest untouched
    z = z*(1-mask[0]) + noise_scale*noise*mask[0] + (1-noise_scale)*z*mask[0]
    return z


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    inverse,
    info,
    guidance: float = 4.0,
    trainable_noise_list=None,
):
    # guidance is ignored for schnell
    inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])

    if inverse:
        timesteps = timesteps[::-1]
        inject_list = inject_list[::-1]
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    step_list = []
    attn_map_list = []
    for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        info['t'] = t_prev if inverse else t_curr
        info['inverse'] = inverse
        info['second_order'] = False
        info['inject'] = inject_list[i]
        # when editing, substitute the optimized latent for the first few steps
        if trainable_noise_list and i != 0 and i < len(trainable_noise_list):
            # smask = info['source_mask'].squeeze(0)
            # img = trainable_noise_list[i]*smask + img*(1-smask)
            img = trainable_noise_list[i]

        pred, info, attn_maps_mid = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            info=info
        )

        img_mid = img + (t_prev - t_curr) / 2 * pred

        t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
        info['second_order'] = True
        pred_mid, info, attn_maps = model(
            img=img_mid,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec_mid,
            guidance=guidance_vec,
            info=info
        )

        first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
        img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order

        # collect attnmaps of shape L,1,512,N per step
        step_list.append(t_curr)
        attn_map_list.append((attn_maps_mid + attn_maps) / 2)

    attn_map_list = torch.stack(attn_map_list)
    return img, info, step_list, attn_map_list


selected_layers = range(8, 44)


def gaussian_smooth(attnmap, wh, kernel_size=3, sigma=0.5):
    # to 2d
    attnmap = rearrange(
        attnmap,
        "b (w h) -> b (w) (h)",
        w=math.ceil(wh[0]/16),
        h=math.ceil(wh[1]/16),
    )
    attnmap = attnmap.unsqueeze(1)
    # prepare kernel
    ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1., device=attnmap.device)
    xx, yy = torch.meshgrid(ax, ax, indexing='ij')
    kernel = torch.exp(-(xx**2 + yy**2) / (2. * sigma**2))
    kernel = kernel / kernel.sum()
    kernel = kernel.view(1, 1, kernel_size, kernel_size)
    kernel = kernel.to(dtype=attnmap.dtype)
    # gaussian smooth
    attnmap_smoothed = F.conv2d(attnmap, kernel, padding=kernel_size // 2)
    return attnmap_smoothed.view(attnmap_smoothed.shape[0], -1)


def compute_attn_max_loss(attnmaps, source_mask, wh):
    # attnmaps L,1,N,k
    attnmaps = attnmaps[selected_layers, 0, :, :]
    attnmaps = attnmaps.mean(dim=-1)
    src_mask = source_mask.view(-1).unsqueeze(0)
    p = attnmaps * src_mask
    p = gaussian_smooth(p, wh, kernel_size=3, sigma=0.5)
    p = p.max(dim=1).values
    loss = (1 - p).mean()
    return loss


def compute_attn_min_loss(attnmaps, source_mask, wh):
    # attnmaps L,1,N,k
    attnmaps = attnmaps[selected_layers, 0, :, :]
    attnmaps = attnmaps.mean(dim=-1)
    src_mask = source_mask.view(-1).unsqueeze(0)
    p = attnmaps * src_mask
    p = gaussian_smooth(p, wh, kernel_size=3, sigma=0.5)
    p = p.max(dim=1).values
    loss = p.mean()
    return loss


def denoise_with_noise_optim(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # loss calculation
    token_ids: list[list[int]],
    source_mask: Tensor,
    training_steps: int,
    training_epochs: int,
    learning_rate: float,
    seed: int,
    noise_scale: float,
    # sampling parameters
    timesteps: list[float],
    info,
    guidance: float = 4.0
):
    # guidance is ignored for schnell
    # print(f'training the noise in last {training_steps} steps and {training_epochs} epochs')
    # timesteps = timesteps[::-1]
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    step_list = []
    attn_map_list = []
    trainable_noise_list = []
    for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        if i >= training_steps:
            break
        # keep the original inputs so every epoch starts from the same conditioning
        ori_txt = txt.clone()
        ori_img = img.clone()
        ori_vec = vec.clone()

        # prepare trainable noise
        if i == 0:
            if noise_scale == 0:
                trainable_noise = torch.nn.Parameter(img.clone().detach(), requires_grad=True)
            else:
                noise = torch.randn(img.shape, device=img.device, dtype=img.dtype, generator=torch.Generator(device=img.device).manual_seed(seed))
                noise = img*(1-source_mask[0]) + noise_scale*noise*source_mask[0] + (1-noise_scale)*img*source_mask[0]
                trainable_noise = torch.nn.Parameter(noise.clone().detach(), requires_grad=True)
        else:
            trainable_noise = torch.nn.Parameter(img.clone().detach(), requires_grad=True)
        optimizer = optim.Adam([trainable_noise], lr=learning_rate)

        # run the optimization epochs for this timestep
        for j in range(training_epochs):
            optimizer.zero_grad()
            txt = ori_txt.clone().detach()
            vec = ori_vec.clone().detach()
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
            info['t'] = t_prev
            info['inverse'] = False
            info['second_order'] = False
            info['inject'] = False  # tried True, seems not necessary
            pred, info, attn_maps_mid = model(
                img=trainable_noise,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                info=info
            )

            img_mid = trainable_noise + (t_prev - t_curr) / 2 * pred

            t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
            info['second_order'] = True
            pred_mid, info, attn_maps = model(
                img=img_mid,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec_mid,
                guidance=guidance_vec,
                info=info
            )

            first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
            img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order

            # attnmaps L,1,N,512 for computing the loss
            attn_maps = (attn_maps_mid + attn_maps) / 2
            total_loss = 0.0
            for indices, change, ratio in token_ids:
                if change:
                    total_loss += compute_attn_max_loss(attn_maps[:, :, :, indices], source_mask, info['wh'])
                else:
                    if ratio != 0:
                        total_loss += ratio * compute_attn_min_loss(attn_maps[:, :, :, indices], source_mask, info['wh'])
            total_loss.backward()
            with torch.no_grad():
                trainable_noise.grad *= source_mask[0]
            optimizer.step()
            print(f"Time {t_curr:.4f} Step {j+1}/{training_epochs}, Loss: {total_loss.item():.6f}")

        attn_map_list.append(attn_maps.detach())
        step_list.append(t_curr)
        trainable_noise = trainable_noise.detach()
        trainable_noise_list.append(trainable_noise.clone())

    attn_map_list = torch.stack(attn_map_list)
    return img, info, step_list, attn_map_list, trainable_noise_list


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
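
A hedged wiring sketch of how these sampling helpers fit together (not part of the commit). Here model, ae, t5, clip, the packed latent x, and width/height are assumed to come from the loaders in flux/util_lore.py and the autoencoder; the info keys shown are only the ones this sketch exercises, and ae.decode is an assumed API:

from flux.sampling_lore import prepare, get_schedule, denoise, unpack

inp = prepare(t5, clip, x, prompt="a photo of a red car")
timesteps = get_schedule(num_steps=25, image_seq_len=inp["img"].shape[1])
info = {"inject_step": 0, "feature": {}, "change_v": 0}  # no V injection in this sketch
z, info, step_list, attn_maps = denoise(model, **inp, timesteps=timesteps, inverse=False, info=info, guidance=4.0)
latent = unpack(z.float(), height, width)  # tokens back to the latent grid
img = ae.decode(latent)                    # decode with the autoencoder (assumed API)
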
flux/util_lore.py
ADDED
@@ -0,0 +1,208 @@
import os
from dataclasses import dataclass

import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from imwatermark import WatermarkEncoder
from safetensors.torch import load_file as load_sft

from flux.model_lore import Flux, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner_lore import HFEmbedder


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None


# download models from the Hugging Face Hub unless local paths are given
flux_path = "black-forest-labs/FLUX.1-dev"
flux_ckpt_path = os.getenv("FLUX_DEV")
flux_ae_path = os.getenv("AE")
t5_path = "google/t5-v1_1-xxl"
clip_path = "openai/clip-vit-large-patch14"

configs = {
    "flux-dev": ModelSpec(
        repo_id=flux_path,
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=flux_ckpt_path,
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=flux_ae_path,
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Loading Flux
    print("Init model")

    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint on", device, ckpt_path)
        # load_sft doesn't support torch.device
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder(t5_path, max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    return HFEmbedder(clip_path, max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)


def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    ckpt_path = configs[name].ae_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)

    # Loading the autoencoder
    print("Init AE")
    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae


class WatermarkEmbedder:
    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        image = 0.5 * image + 0.5
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) in [0, 255]
        # the watermarking library expects input in cv2 BGR format
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
            image.device
        )
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        image = 2 * image - 1
        return image


# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
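
Typical loading path for these helpers (a sketch, not part of the commit); weights are fetched from the Hugging Face Hub unless FLUX_DEV / AE point at local safetensors files:

import torch
from flux.util_lore import load_flow_model, load_ae, load_t5, load_clip

device = "cuda" if torch.cuda.is_available() else "cpu"
t5 = load_t5(device, max_length=512)
clip = load_clip(device)
model = load_flow_model("flux-dev", device=device)
ae = load_ae("flux-dev", device=device)
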
requirements.txt
ADDED
@@ -0,0 +1,19 @@
pydantic==2.10.6
torch
einops
accelerate==0.34.2
einops==0.8.0
transformers==4.41.2
huggingface-hub==0.24.6
datasets
omegaconf
diffusers
sentencepiece
opencv-python
matplotlib
onnxruntime
torchvision
timm
invisible-watermark
fire
tqdm