Update
- app.py +269 -0
- flux/__init__.py +0 -0
- flux/math.py +94 -0
- flux/model.py +120 -0
- flux/modules/attention_flax.py +494 -0
- flux/modules/autoencoder.py +343 -0
- flux/modules/conditioner.py +73 -0
- flux/modules/layers.py +292 -0
- flux/sampling.py +227 -0
- flux/util.py +340 -0
- flux/wrapper.py +139 -0
- requirements.txt +14 -0
app.py
ADDED
@@ -0,0 +1,269 @@
import os
import time
from io import BytesIO
import uuid

import torch
import gradio as gr
import spaces
import numpy as np
from einops import rearrange
from PIL import Image, ExifTags

from dataclasses import dataclass

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5


import jax
import jax.numpy as jnp
from flax import nnx
from jax import Array as Tensor
from einops import repeat


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None


NSFW_THRESHOLD = 0.85


@spaces.GPU
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    # nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    # return model, ae, t5, clip, nsfw_classifier
    return nnx.split(model), nnx.split(ae), nnx.split(t5), t5.tokenizer, nnx.split(clip), clip.tokenizer, None


@jax.jit
def encode(ae, x):
    ae = nnx.merge(*ae)
    return ae.encode(x)


def _generate(model, ae, t5, clip, x, t5_tokens, clip_tokens, num_steps, guidance,
              # init_image=None,
              # image2image_strength=0.0,
              shift=True):
    b, h, w, c = x.shape
    model = nnx.merge(*model)
    ae = nnx.merge(*ae)
    t5 = nnx.merge(*t5)
    clip = nnx.merge(*clip)
    timesteps = get_schedule(
        num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=shift,
    )
    # if init_image is not None:
    #     t_idx = int((1 - image2image_strength) * num_steps)
    #     t = timesteps[t_idx]
    #     timesteps = timesteps[t_idx:]
    #     x = t * x + (1.0 - t) * init_image.astype(x.dtype)
    inp = prepare(t5, clip, x, t5_tokens, clip_tokens)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
    x = unpack(x.astype(jnp.float32), h * 8, w * 8)
    x = ae.decode(x)
    return x


generate = jax.jit(_generate, static_argnames=("num_steps", "shift"))


def prepare_tokens(t5_tokenizer, clip_tokenizer, prompt: str | list[str]) -> tuple[Tensor, Tensor]:
    if isinstance(prompt, str):
        prompt = [prompt]
    t5_tokens = t5_tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        return_length=False,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="jax",
    )["input_ids"]
    clip_tokens = clip_tokenizer(
        prompt,
        truncation=True,
        max_length=77,
        return_length=False,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="jax",
    )["input_ids"]
    return t5_tokens, clip_tokens


class FluxGenerator:
    def __init__(self, model_name: str, device: str, offload: bool):
        self.device = None
        self.offload = offload
        self.model_name = model_name
        self.is_schnell = model_name == "flux-schnell"
        self.model, self.ae, self.t5, self.t5_tokenizer, self.clip, self.clip_tokenizer, self.nsfw_classifier = get_models(
            model_name,
            device=self.device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )
        self.key = jax.random.key(0)

    @spaces.GPU(duration=180)
    def generate_image(
        self,
        img_size,
        num_steps,
        guidance,
        seed,
        prompt,
        # init_image=None,
        # image2image_strength=0.0,
        add_sampling_metadata=True,
    ):
        seed = int(seed)
        if seed == -1:
            seed = None
        if img_size == "1,024x1,024":
            width, height = 1024, 1024
        else:
            width, height = 512, 512

        opts = SamplingOptions(
            prompt=prompt,
            width=width,
            height=height,
            num_steps=num_steps,
            guidance=guidance,
            seed=seed,
        )

        if opts.seed is None:
            # opts.seed = torch.Generator(device="cpu").seed()
            key, self.key = jax.random.split(self.key, 2)
            opts.seed = jax.random.randint(key, (), 0, 2**30)
        print(f"Generating '{opts.prompt}' with seed {opts.seed}")
        t0 = time.perf_counter()

        # if init_image is not None:
        #     if isinstance(init_image, np.ndarray):
        #         init_image = jnp.asarray(init_image).astype(jnp.float32) / 255.0
        #     init_image = init_image[None]
        #     # init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
        #     init_image = jax.image.resize(init_image, (opts.height, opts.width), method="lanczos5")
        #     # if self.offload:
        #     #     self.ae.encoder.to(self.device)
        #     # init_image = self.ae.encode(init_image)
        #     init_image = encode(self.ae, init_image)

        # prepare input
        t5_tokens, clip_tokens = prepare_tokens(self.t5_tokenizer, self.clip_tokenizer, prompt=opts.prompt)
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=None,
            dtype=jnp.bfloat16,
            seed=opts.seed,
        )

        x = generate(self.model, self.ae, self.t5, self.clip, x, t5_tokens, clip_tokens, opts.num_steps, opts.guidance, shift=(not self.is_schnell))

        t1 = time.perf_counter()
        runtime = t1 - t0
        # print(f"Done in {t1 - t0:.1f}s.")

        # bring into PIL format
        x = jnp.clip(x, -1, 1)
        # x = embed_watermark(x.astype(jnp.float32))
        # x = rearrange(x[0], "c h w -> h w c")
        img = Image.fromarray(np.asarray(127.5 * (x[0] + 1.0)).astype(np.uint8))
        # img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
        # nsfw_score = [x["score"] for x in self.nsfw_classifier(img) if x["label"] == "nsfw"][0]

        if True:
            filename = f"output/gradio/{uuid.uuid4()}.jpg"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            exif_data = Image.Exif()
            # if init_image is None:
            exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            # else:
            #     exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = self.model_name
            if add_sampling_metadata:
                exif_data[ExifTags.Base.ImageDescription] = prompt

            img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

            return img, runtime, str(opts.seed), filename, None
        else:
            return None, str(opts.seed), None, "Your generated image may contain NSFW content."


@spaces.GPU(duration=300)
def create_demo(model_name: str, device: str = "cuda", offload: bool = False):
    generator = FluxGenerator(model_name, device, offload)
    is_schnell = model_name == "flux-schnell"

    with open("./assets/banner.html") as f:
        banner = f.read()
    with gr.Blocks() as demo:
        with gr.Column(elem_id="app-container"):
            gr.HTML(f"""<iframe scrolling="no" style="width: 100%; height: 125px; border: 0" srcdoc='{banner}'>""")
            gr.Markdown("""🚀 [Flux-Flax](https://github.com/lkwq007/flux-flax) is a JAX implementation of Flux models. 1-step time statistics for `FLUX.1-schnell`: `0.4s` for 1024x1024, `0.1s` for 512x512; 2-step: `0.6s` for 1024x1024, `0.2s` for 512x512; 4-step: `2.4s` for 1024x1024, `0.8s` for 512x512.
""")

            with gr.Row():
                with gr.Column(scale=3):
                    output_image = gr.Image(label="Generated Image")
                    warning_text = gr.Textbox(label="Warning", visible=False)
                    download_btn = gr.File(label="Download full-resolution")
                    gr.Markdown("""
💡 Note: more resolutions are supported, but this demo is limited to 1024x1024 and 512x512 to avoid JIT recompilation (which takes about 130s). Flux-Flax also supports `FLUX.1-dev`; 50-step time statistics: `18s` for 1024x1024, `6s` for 512x512""")
                with gr.Column(scale=1):
                    prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture")
                    generate_btn = gr.Button("Generate")
                    with gr.Row():
                        seed_output = gr.Number(label="Used Seed")
                        runtime = gr.Number(label="Inference Time", precision=3)
                    with gr.Row():
                        seed = gr.Textbox(-1, label="Seed (-1 for random)")
                        img_size = gr.Radio(["1,024x1,024", "512x512"], label="Image Resolution", value="1,024x1,024")
                    num_steps = gr.Slider(1, 4, 1, step=1, label="Number of steps")
                    add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)
                    guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell, visible=False)

            # def update_img2img(do_img2img):
            #     return {
            #         init_image: gr.update(visible=do_img2img),
            #         image2image_strength: gr.update(visible=do_img2img),
            #     }

            # do_img2img.change(update_img2img, do_img2img, [init_image, image2image_strength])
            generate_btn.click(
                fn=generator.generate_image,
                inputs=[img_size, num_steps, guidance, seed, prompt, add_sampling_metadata],
                outputs=[output_image, runtime, seed_output, download_btn, warning_text],
            )

    return demo


# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser(description="Flux")
#     parser.add_argument("--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name")
#     parser.add_argument("--device", type=str, default="cpu", help="Device to use")
#     parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
#     parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
#     args = parser.parse_args()

demo = create_demo("flux-schnell", None, False)
demo.launch()
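The split/merge pattern in app.py is what lets the Flax NNX modules cross the `jax.jit` boundary: `get_models` returns each model as a `(graphdef, state)` pair from `nnx.split`, and the jitted functions rebuild the module with `nnx.merge` inside the trace. A minimal sketch of the same pattern, assuming only `jax` and `flax.nnx` are installed (the `nnx.Linear` layer is a stand-in, not part of this repo):

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))   # stand-in for the Flux modules
split_model = nnx.split(model)               # (graphdef, state): jit-friendly pieces

@jax.jit
def forward(split_model, x):
    m = nnx.merge(*split_model)              # rebuild the module inside the traced function
    return m(x)

y = forward(split_model, jnp.ones((1, 2)))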
flux/__init__.py
ADDED
File without changes
flux/math.py
ADDED
@@ -0,0 +1,94 @@
# import torch
import math
import jax
import jax.numpy as jnp
from einops import rearrange
from flax import nnx

Tensor = jax.Array


def check_tpu():
    return any('TPU' in d.device_kind for d in jax.devices())


# from torch import Tensor
if check_tpu():
    from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

    # q,  # [batch_size, num_heads, q_seq_len, d_model]
    # k,  # [batch_size, num_heads, kv_seq_len, d_model]
    # v,  # [batch_size, num_heads, kv_seq_len, d_model]
    def flash_mha(q, k, v):
        return flash_attention(q, k, v, sm_scale=1 / math.sqrt(q.shape[-1]))
else:
    from jax.experimental.pallas.ops.gpu.attention import mha, mha_reference

    def pallas_mha(q, k, v):
        # B L H D
        # return mha_reference(q, k, v, segment_ids=None, sm_scale=1/math.sqrt(q.shape[-1]))
        q_len = q.shape[1]
        # pad the sequence length up to a multiple of 128 for the pallas kernel
        diff = (-q_len) & 127
        segment_ids = jnp.zeros((q.shape[0], q.shape[1]), dtype=jnp.int32)
        segment_ids = jnp.pad(segment_ids, ((0, 0), (0, diff)), mode="constant", constant_values=1)
        # q, k, v = map(lambda x: jnp.pad(x, ((0, 0), (0, diff), (0, 0), (0, 0)), mode="constant", constant_values=0), (q, k, v))
        return mha(q, k, v, segment_ids=segment_ids, sm_scale=1 / math.sqrt(q.shape[-1]))  # [:, :q_len]

    # mha: batch_size, seq_len, num_heads, head_dim = q.shape
    from functools import partial
    from flux.modules.attention_flax import jax_memory_efficient_attention
    try:
        from flash_attn_jax import flash_mha
    except ImportError:
        flash_mha = pallas_mha
    # flash_mha = nnx.dot_product_attention


def dot_product_attention(q, k, v, sm_scale=1.0):
    q, k, v = map(lambda x: rearrange(x, "b h n d -> b n h d"), (q, k, v))
    # ret = pallas_mha(q, k, v)
    ret = nnx.dot_product_attention(q, k, v)
    # if q.shape[-3] % 64 == 0:
    #     query_chunk_size = int(q.shape[-3] / 64)
    # elif q.shape[-3] % 16 == 0:
    #     query_chunk_size = int(q.shape[-3] / 16)
    # elif q.shape[-3] % 4 == 0:
    #     query_chunk_size = int(q.shape[-3] / 4)
    # else:
    #     query_chunk_size = int(q.shape[-3])
    # ret = jax_memory_efficient_attention(q, k, v, query_chunk_size=query_chunk_size)
    return rearrange(ret, "b n h d -> b h n d")


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
    # x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # q is B H L D
    q, k, v = map(lambda x: rearrange(x, "B H L D -> B L H D"), (q, k, v))
    # x = nnx.dot_product_attention(q, k, v)
    x = flash_mha(q, k, v)
    # x = pallas_mha(q, k, v)
    # x = mha(q, k, v, None, sm_scale=1/math.sqrt(q.shape[-1]))
    x = rearrange(x, "B L H D -> B L (H D)")

    # x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    # scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
    omega = 1.0 / (theta**scale)
    # out = torch.einsum("...n,d->...nd", pos, omega)
    out = jnp.einsum("...n,d->...nd", pos.astype(jnp.float32), omega)
    # out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    # return out.float()
    return out.astype(jnp.float32)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    # return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
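The `rope` helper builds a (cos, -sin, sin, cos) 2x2 rotation per frequency, and `apply_rope` contracts it against queries and keys split into 2-vectors. A quick shape sanity check, as a sketch: the head axis added with `[:, None]` mirrors what the positional embedder is expected to produce (flux/modules/layers.py is part of this commit but not shown here), so treat that detail as an assumption.

import jax.numpy as jnp
from flux.math import rope, apply_rope

B, H, L, D = 1, 2, 8, 16
pos = jnp.arange(L, dtype=jnp.float32)[None, :]   # positions, shape (B, L)
pe = rope(pos, D, 10_000)[:, None]                # (B, 1, L, D//2, 2, 2), broadcast over heads
q = jnp.ones((B, H, L, D))
k = jnp.ones((B, H, L, D))
q_rot, k_rot = apply_rope(q, k, pe)
assert q_rot.shape == (B, H, L, D) and k_rot.shape == (B, H, L, D)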
flux/model.py
ADDED
@@ -0,0 +1,120 @@
from dataclasses import dataclass


import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx

from flux.wrapper import TorchWrapper

from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
                                 MLPEmbedder, SingleStreamBlock,
                                 timestep_embedding)


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


DoubleStreamBlock_class, EmbedND_class, LastLayer_class, MLPEmbedder_class, SingleStreamBlock_class = DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock


class Flux(nnx.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams, dtype: jnp.dtype = jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)
        DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock = nn.declare_with_rng(DoubleStreamBlock_class, EmbedND_class, LastLayer_class, MLPEmbedder_class, SingleStreamBlock_class)
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def __call__(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # ids = torch.cat((txt_ids, img_ids), dim=1)
        ids = jnp.concatenate((txt_ids, img_ids), axis=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # img = torch.cat((txt, img), 1)
        img = jnp.concatenate((txt, img), axis=1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
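For reference, the model above is driven entirely by the `FluxParams` dataclass. A sketch of constructing it with the hyperparameters Black Forest Labs publishes for `flux-schnell`; the authoritative values live in flux/util.py's `configs` (part of this commit but not shown), so treat the numbers below as assumptions:

import jax.numpy as jnp
from flax import nnx
from flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=64, vec_in_dim=768, context_in_dim=4096,
    hidden_size=3072, mlp_ratio=4.0, num_heads=24,
    depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56],
    theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params, dtype=jnp.bfloat16, rngs=nnx.Rngs(0))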
flux/modules/attention_flax.py
ADDED
@@ -0,0 +1,494 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import flax.linen as nn
import jax
import jax.numpy as jnp


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights


def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for the computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size used to split the query array; must divide query_length without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size used to split the key and value arrays; must divide key_value_length without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # unused; ignore it
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        f=chunk_scanner,
        init=0,  # start counter
        xs=None,
        length=math.ceil(num_q / query_chunk_size),  # stop counter
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back


class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

    Parameters:
        query_dim (:obj:`int`):
            Input hidden states dimension
        heads (:obj:`int`, *optional*, defaults to 8):
            Number of heads
        dim_head (:obj:`int`, *optional*, defaults to 64):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory-efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim_head * self.heads
        self.scale = self.dim_head**-0.5

        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def __call__(self, hidden_states, context=None, deterministic=True):
        context = hidden_states if context is None else context

        query_proj = self.query(hidden_states)
        key_proj = self.key(context)
        value_proj = self.value(context)

        if self.split_head_dim:
            b = hidden_states.shape[0]
            query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
            key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
            value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
        else:
            query_states = self.reshape_heads_to_batch_dim(query_proj)
            key_states = self.reshape_heads_to_batch_dim(key_proj)
            value_states = self.reshape_heads_to_batch_dim(value_proj)

        if self.use_memory_efficient_attention:
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # this if statement creates a chunk size for each layer of the unet;
            # the chunk size is equal to the query_length dimension of the deepest layer of the unet
            flatten_latent_dim = query_states.shape[-3]
            if flatten_latent_dim % 64 == 0:
                query_chunk_size = int(flatten_latent_dim / 64)
            elif flatten_latent_dim % 16 == 0:
                query_chunk_size = int(flatten_latent_dim / 16)
            elif flatten_latent_dim % 4 == 0:
                query_chunk_size = int(flatten_latent_dim / 4)
            else:
                query_chunk_size = int(flatten_latent_dim)

            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )

            hidden_states = hidden_states.transpose(1, 0, 2)
        else:
            # compute attentions
            if self.split_head_dim:
                attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
            else:
                attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)

            attention_scores = attention_scores * self.scale
            attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)

            # attend to values
            if self.split_head_dim:
                hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
                b = hidden_states.shape[0]
                hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
            else:
                hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
                hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        hidden_states = self.proj_attn(hidden_states)
        return self.dropout_layer(hidden_states, deterministic=deterministic)


class FlaxBasicTransformerBlock(nn.Module):
    r"""
    A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
    https://arxiv.org/abs/1706.03762

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        only_cross_attention (`bool`, defaults to `False`):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory-efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False

    def setup(self):
        # self attention (or cross attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(
            self.dim,
            self.n_heads,
            self.d_head,
            self.dropout,
            self.use_memory_efficient_attention,
            self.split_head_dim,
            dtype=self.dtype,
        )
        # cross attention
        self.attn2 = FlaxAttention(
            self.dim,
            self.n_heads,
            self.d_head,
            self.dropout,
            self.use_memory_efficient_attention,
            self.split_head_dim,
            dtype=self.dtype,
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
        if self.only_cross_attention:
            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
        else:
            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        return self.dropout_layer(hidden_states, deterministic=deterministic)


class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
    https://arxiv.org/pdf/1506.02025.pdf

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformer blocks
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`): tbd
        only_cross_attention (`bool`, defaults to `False`): tbd
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory-efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        if self.use_linear_projection:
            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_in = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
            )
            for _ in range(self.depth)
        ]

        if self.use_linear_projection:
            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_out = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        if self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height * width, channels)
            hidden_states = self.proj_in(hidden_states)
        else:
            hidden_states = self.proj_in(hidden_states)
            hidden_states = hidden_states.reshape(batch, height * width, channels)

        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)

        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channels)
        else:
            hidden_states = hidden_states.reshape(batch, height, width, channels)
            hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states + residual
        return self.dropout_layer(hidden_states, deterministic=deterministic)


class FlaxFeedForward(nn.Module):
    r"""
    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
      https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
        hidden_states = self.net_2(hidden_states)
        return hidden_states


class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
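As the docstring notes, `jax_memory_efficient_attention` never materializes the full attention matrix: queries are processed in `query_chunk_size` slices via `lax.scan`, and each slice reduces over key/value chunks with a running max for numerical stability. A minimal call, as a sketch; shapes are (batch, length, heads, head_dim), with chunk sizes chosen to divide the lengths as required:

import jax
import jax.numpy as jnp
from flux.modules.attention_flax import jax_memory_efficient_attention

key = jax.random.key(0)
q = jax.random.normal(key, (1, 1024, 8, 64))   # query_length = 1024
k = jax.random.normal(key, (1, 4096, 8, 64))   # key_value_length = 4096
v = jax.random.normal(key, (1, 4096, 8, 64))
out = jax_memory_efficient_attention(q, k, v, query_chunk_size=256, key_chunk_size=1024)
assert out.shape == q.shape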
flux/modules/autoencoder.py
ADDED
@@ -0,0 +1,343 @@
from dataclasses import dataclass

from einops import rearrange

import jax
import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx

from flux.wrapper import TorchWrapper
from flux.math import dot_product_attention


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


swish = nnx.swish


class AttnBlock(nnx.Module):
    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # b, c, h, w = q.shape
        b, h, w, c = q.shape

        # q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        # k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        # v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        q = rearrange(q, "b h w c -> b 1 (h w) c")
        k = rearrange(k, "b h w c -> b 1 (h w) c")
        v = rearrange(v, "b h w c -> b 1 (h w) c")
        # h_ = nn.functional.scaled_dot_product_attention(q, k, v)
        h_ = dot_product_attention(q, k, v)

        # return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
        return rearrange(h_, "b 1 (h w) c -> b h w c", h=h, w=w, c=c, b=b)

    def __call__(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nnx.Module):
    def __init__(self, in_channels: int, out_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def __call__(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nnx.Module):
    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def __call__(self, x: Tensor):
        # pad = (0, 1, 0, 1)
        # x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)), mode="constant")
        x = self.conv(x)
        return x


class Upsample(nnx.Module):
    def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs, dtype=dtype)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def __call__(self, x: Tensor):
        # x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        B, H, W, C = x.shape
        x = jax.image.resize(x, (B, H * 2, W * 2, C), method="nearest")
        x = self.conv(x)
        return x


ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class = ResnetBlock, Downsample, Upsample, AttnBlock


class Encoder(nnx.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        dtype=jnp.float32,
        rngs: nnx.Rngs = None,
    ):
        nn = TorchWrapper(rngs, dtype=dtype)
        ResnetBlock, Downsample, Upsample, AttnBlock = nn.declare_with_rng(ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def __call__(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nnx.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
        dtype=jnp.float32,
        rngs: nnx.Rngs = None,
    ):
        nn = TorchWrapper(rngs, dtype=dtype)
        ResnetBlock, Downsample, Upsample, AttnBlock = nn.declare_with_rng(ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def __call__(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nnx.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = -1, dtype=jnp.float32, rngs: nnx.Rngs = None):
        self.sample = sample
        self.chunk_dim = chunk_dim
        self.rngs = rngs
        self.dtype = dtype

    def __call__(self, z: Tensor) -> Tensor:
        # mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        mean, logvar = jnp.split(z, 2, axis=self.chunk_dim)
        if self.sample:
            # std = torch.exp(0.5 * logvar)
            # return mean + std * torch.randn_like(mean)
            std = jnp.exp(0.5 * logvar)
            return mean + std * jax.random.normal(self.rngs(), mean.shape)
        else:
            return mean
|
303 |
+
|
304 |
+
|
305 |
+
Encoder_class, Decoder_class, DiagonalGaussian_class = Encoder, Decoder, DiagonalGaussian
|
306 |
+
|
307 |
+
class AutoEncoder(nnx.Module):
|
308 |
+
def __init__(self, params: AutoEncoderParams, dtype=jnp.float32, rngs: nnx.Rngs = None):
|
309 |
+
nn = TorchWrapper(rngs, dtype=dtype)
|
310 |
+
Encoder, Decoder, DiagonalGaussian = nn.declare_with_rng(Encoder_class, Decoder_class, DiagonalGaussian_class)
|
311 |
+
self.encoder = Encoder(
|
312 |
+
resolution=params.resolution,
|
313 |
+
in_channels=params.in_channels,
|
314 |
+
ch=params.ch,
|
315 |
+
ch_mult=params.ch_mult,
|
316 |
+
num_res_blocks=params.num_res_blocks,
|
317 |
+
z_channels=params.z_channels,
|
318 |
+
)
|
319 |
+
self.decoder = Decoder(
|
320 |
+
resolution=params.resolution,
|
321 |
+
in_channels=params.in_channels,
|
322 |
+
ch=params.ch,
|
323 |
+
out_ch=params.out_ch,
|
324 |
+
ch_mult=params.ch_mult,
|
325 |
+
num_res_blocks=params.num_res_blocks,
|
326 |
+
z_channels=params.z_channels,
|
327 |
+
)
|
328 |
+
self.reg = DiagonalGaussian()
|
329 |
+
|
330 |
+
self.scale_factor = params.scale_factor
|
331 |
+
self.shift_factor = params.shift_factor
|
332 |
+
|
333 |
+
def encode(self, x: Tensor) -> Tensor:
|
334 |
+
z = self.reg(self.encoder(x))
|
335 |
+
z = self.scale_factor * (z - self.shift_factor)
|
336 |
+
return z
|
337 |
+
|
338 |
+
def decode(self, z: Tensor) -> Tensor:
|
339 |
+
z = z / self.scale_factor + self.shift_factor
|
340 |
+
return self.decoder(z)
|
341 |
+
|
342 |
+
def __call__(self, x: Tensor) -> Tensor:
|
343 |
+
return self.decode(self.encode(x))
|
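For orientation, a minimal sketch of the NHWC encode/decode round trip. Modules built through TorchWrapper start with placeholder parameters (see fake_init in flux/wrapper.py below), so real weights must be loaded first; the shapes assume the ch_mult=[1, 2, 4, 4], z_channels=16 config from flux/util.py.

# A sketch, not part of the committed files: round-trip an image through the AE.
import jax.numpy as jnp
from flux.util import load_ae

ae = load_ae("flux-schnell")                    # downloads ae.safetensors and fills in weights
x = jnp.zeros((1, 256, 256, 3), jnp.bfloat16)   # NHWC, unlike the torch original's NCHW
z = ae.encode(x)                                # (1, 32, 32, 16): 2**3 = 8x spatial downsample
x_rec = ae.decode(z)                            # back to (1, 256, 256, 3)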
flux/modules/conditioner.py
ADDED
@@ -0,0 +1,73 @@
from flax import nnx
import jax.numpy as jnp
from jax import Array as Tensor

from transformers import (FlaxCLIPTextModel, CLIPTokenizer, FlaxT5EncoderModel,
                          T5Tokenizer)


class HFEmbedder(nnx.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
        dtype = hf_kwargs.get("dtype", jnp.float32)
        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            # self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
            self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            # self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
            self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(version, **hf_kwargs)
        if dtype == jnp.bfloat16:
            self.hf_module.params = self.hf_module.to_bf16(self.hf_module.params)

    def tokenize(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="jax",
        )
        return batch_encoding["input_ids"]

    def __call__(self, input_ids: Tensor) -> Tensor:
        # outputs = self.hf_module(
        #     input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
        #     attention_mask=None,
        #     output_hidden_states=False,
        # )
        outputs = self.hf_module(
            input_ids=input_ids,
            attention_mask=None,
            output_hidden_states=False,
            train=False,
        )
        return outputs[self.output_key]

    # def __call__(self, text: list[str]) -> Tensor:
    #     batch_encoding = self.tokenizer(
    #         text,
    #         truncation=True,
    #         max_length=self.max_length,
    #         return_length=False,
    #         return_overflowing_tokens=False,
    #         padding="max_length",
    #         return_tensors="jax",
    #     )
    #     outputs = self.hf_module(
    #         input_ids=batch_encoding["input_ids"],
    #         attention_mask=None,
    #         output_hidden_states=False,
    #         train=False,
    #     )
    #     return outputs[self.output_key]
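A short usage sketch of HFEmbedder as wired up by load_t5 in flux/util.py below: tokenization runs on the host, embedding takes plain arrays, so only the latter has to live inside jit. The shapes assume the 256-token schnell setting and T5-XXL's 4096-dim hidden state; the encoder weights are downloaded on first use.

# A sketch, not part of the committed files.
from flux.util import load_t5

t5 = load_t5(max_length=256)                 # lnyan/t5-v1_1-xxl-encoder in bf16
tokens = t5.tokenize(["a photo of a cat"])   # (1, 256) integer input ids
emb = t5(tokens)                             # (1, 256, 4096) last_hidden_state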
flux/modules/layers.py
ADDED
@@ -0,0 +1,292 @@
import math
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx
from einops import rearrange

from flux.wrapper import TorchWrapper
from flux.math import attention, rope


class EmbedND(nnx.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int], dtype=jnp.float32, rngs: nnx.Rngs = None):
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def __call__(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        # emb = torch.cat(
        #     [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
        #     dim=-3,
        # )
        emb = jnp.concatenate(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            axis=-3,
        )

        # return emb.unsqueeze(1)
        return jnp.expand_dims(emb, 1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
    #     t.device
    # )
    freqs = jnp.exp(-math.log(max_period) * jnp.arange(half, dtype=jnp.float32) / half)

    # args = t[:, None].float() * freqs[None]
    # embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    args = t[:, None] * freqs[None]
    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    if dim % 2:
        # embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
    # if torch.is_floating_point(t):
    #     embedding = embedding.to(t)
    if jnp.issubdtype(t.dtype, jnp.floating):
        embedding = embedding.astype(t.dtype)
    return embedding

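A quick shape check for the sinusoidal embedding above (a sketch; dim=256 is an arbitrary choice for illustration):

import jax.numpy as jnp
from flux.modules.layers import timestep_embedding

t = jnp.array([0.25, 1.0])               # fractional timesteps are fine
emb = timestep_embedding(t, dim=256)     # (2, 256): 128 cosines followed by 128 sines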
class MLPEmbedder(nnx.Module):
    def __init__(self, in_dim: int, hidden_dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)

        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(nnx.Module):
    def __init__(self, dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)
        # self.scale = nn.Parameter(torch.ones(dim))
        self.scale = nn.Parameter(jnp.ones((dim,)))

    def __call__(self, x: Tensor):
        x_dtype = x.dtype
        # x = x.float()
        x = x.astype(jnp.float32)
        # rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        rrms = jax.lax.rsqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6)
        # return (x * rrms).to(dtype=x_dtype) * self.scale
        # note: cast back to the caller's dtype, not x.dtype (x was upcast to float32 above)
        return (x * rrms).astype(x_dtype) * self.scale


RMSNorm_class = RMSNorm


class QKNorm(nnx.Module):
    def __init__(self, dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)
        RMSNorm = nn.declare_with_rng(RMSNorm_class)
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def __call__(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        # return q.to(v), k.to(v)
        return q.astype(v.dtype), k.astype(v.dtype)


QKNorm_class = QKNorm


class SelfAttention(nnx.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)
        QKNorm = nn.declare_with_rng(QKNorm_class)
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def __call__(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nnx.Module):
    def __init__(self, dim: int, double: bool, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def __call__(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        # out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
        out = self.lin(nnx.silu(vec))[:, None, :]
        out = jnp.split(out, self.multiplier, axis=-1)
        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


Modulation_class, SelfAttention_class = Modulation, SelfAttention


class DoubleStreamBlock(nnx.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)
        Modulation, SelfAttention = nn.declare_with_rng(Modulation_class, SelfAttention_class)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def __call__(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        # q = torch.cat((txt_q, img_q), dim=2)
        # k = torch.cat((txt_k, img_k), dim=2)
        # v = torch.cat((txt_v, img_v), dim=2)
        q = jnp.concatenate((txt_q, img_q), axis=2)
        k = jnp.concatenate((txt_k, img_k), axis=2)
        v = jnp.concatenate((txt_v, img_v), axis=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class SingleStreamBlock(nnx.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
        dtype=jnp.float32,
        rngs: nnx.Rngs = None,
    ):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)
        QKNorm, Modulation = nn.declare_with_rng(QKNorm_class, Modulation_class)
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def __call__(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        # qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        qkv, mlp = jnp.split(self.linear1(x_mod), [3 * self.hidden_size], axis=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        # output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        output = self.linear2(jnp.concatenate((attn, self.mlp_act(mlp)), axis=2))
        return x + mod.gate * output


class LastLayer(nnx.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
        nn = TorchWrapper(rngs=rngs, dtype=dtype)
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def __call__(self, x: Tensor, vec: Tensor) -> Tensor:
        # shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        shift, scale = jnp.split(self.adaLN_modulation(vec), 2, axis=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
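To make the Modulation plumbing above concrete, here is the split it performs with double=True, written out in plain jnp ops (a sketch; hidden_size=3072 is taken from the configs in flux/util.py below):

import jax.numpy as jnp

out = jnp.zeros((1, 1, 6 * 3072))        # lin(silu(vec))[:, None, :] with double=True
chunks = jnp.split(out, 6, axis=-1)      # six (1, 1, 3072) tensors
mod1 = chunks[:3]                        # (shift, scale, gate) for the attention path
mod2 = chunks[3:]                        # (shift, scale, gate) for the MLP path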
flux/sampling.py
ADDED
@@ -0,0 +1,227 @@
import math
from typing import Callable

from einops import rearrange, repeat

import jax
import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx

from flux.model import Flux
from flux.modules.conditioner import HFEmbedder


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device,
    dtype: jnp.dtype,
    seed: int,
):
    # return torch.randn(
    #     num_samples,
    #     16,
    #     # allow for packing
    #     2 * math.ceil(height / 16),
    #     2 * math.ceil(width / 16),
    #     device=device,
    #     dtype=dtype,
    #     generator=torch.Generator(device=device).manual_seed(seed),
    # )
    key = jax.random.key(seed)
    return jax.random.normal(
        key,
        (
            num_samples,
            2 * math.ceil(height / 16),
            2 * math.ceil(width / 16),
            16,
        ),
        dtype=dtype,
    )


def prepare_tokens(t5: HFEmbedder, clip: HFEmbedder, prompt: str | list[str]) -> tuple[Tensor, Tensor]:
    if isinstance(prompt, str):
        prompt = [prompt]
    t5_tokens = t5.tokenize(prompt)
    clip_tokens = clip.tokenize(prompt)
    return t5_tokens, clip_tokens


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, t5_tokens: Tensor, clip_tokens: Tensor) -> dict[str, Tensor]:
    # bs, c, h, w = img.shape
    bs, h, w, c = img.shape

    if bs == 1:
        bs = t5_tokens.shape[0]

    # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    img = rearrange(img, "b (h ph) (w pw) c -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids = jnp.zeros((h // 2, w // 2, 3))
    # img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    # img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = img_ids.at[..., 1].set(img_ids[..., 1] + jnp.arange(h // 2)[:, None])
    img_ids = img_ids.at[..., 2].set(img_ids[..., 2] + jnp.arange(w // 2)[None, :])
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    txt = t5(t5_tokens)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    # txt_ids = torch.zeros(bs, txt.shape[1], 3)
    txt_ids = jnp.zeros((bs, txt.shape[1], 3))

    vec = clip(clip_tokens)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "vec": vec,
    }


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> Tensor:
    # extra step for zero
    # timesteps = torch.linspace(1, 0, num_steps + 1)
    timesteps = jnp.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps


DEBUG = False


def denoise_for(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: Tensor,
    guidance: float = 4.0,
):
    # this is ignored for schnell
    # guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    guidance_vec = jnp.full((img.shape[0],), guidance, dtype=img.dtype)
    timesteps = timesteps.tolist()
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        # t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        t_vec = jnp.full((img.shape[0],), t_curr, dtype=img.dtype)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        img = img + (t_prev - t_curr) * pred
    return img


# @nnx.jit
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: Tensor,
    guidance: float = 4.0,
):
    # this is ignored for schnell
    guidance_vec = jnp.full((img.shape[0],), guidance, dtype=img.dtype)

    @nnx.scan
    def scan_func(acc, t_prev):
        img, t_curr = acc
        dtype = img.dtype
        t_vec = jnp.full((img.shape[0],), t_curr, dtype=img.dtype)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        img = img + (t_prev - t_curr) * pred
        return (img.astype(dtype), t_prev), pred

    acc, pred = scan_func((img, timesteps[0]), timesteps[1:])
    return acc[0]


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    # return rearrange(
    #     x,
    #     "b (h w) (c ph pw) -> b c (h ph) (w pw)",
    #     h=math.ceil(height / 16),
    #     w=math.ceil(width / 16),
    #     ph=2,
    #     pw=2,
    # )
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b (h ph) (w pw) c",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
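The noise and schedule helpers above are pure functions and can be exercised without any model weights; a small sketch of the latent layout and the timestep schedule (the 512x512 size and 4 steps are arbitrary example values):

import jax.numpy as jnp
from flux.sampling import get_noise, get_schedule

x = get_noise(1, 512, 512, device=None, dtype=jnp.bfloat16, seed=0)
# x: (1, 64, 64, 16) -- NHWC latents, 2 * ceil(512 / 16) = 64 per side

ts = get_schedule(4, image_seq_len=(512 // 16) * (512 // 16), shift=False)
# 5 values from 1.0 down to 0.0; shift=True would warp them via time_shift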
flux/util.py
ADDED
@@ -0,0 +1,340 @@
import os
from dataclasses import dataclass

import numpy as np
import jax
from jax import Array as Tensor
import jax.numpy as jnp
from flax import nnx
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from imwatermark import WatermarkEncoder
from safetensors.torch import load_file as load_sft

from flux.model import Flux, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner import HFEmbedder


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


try:
    import ml_dtypes
    # reinterpret the raw bf16 bits instead of a lossy round trip through float32
    from_torch_bf16 = lambda x: jnp.asarray(x.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16))
except ImportError:
    from_torch_bf16 = lambda x: jnp.asarray(x.float().numpy()).astype(jnp.bfloat16)


def load_from_torch(graph, state, state_dict: dict):
    cnt = 0
    torch_cnt = 0
    flax_cnt = 0
    val_cnt = 0
    print(f"Torch states: #{len(state_dict)}; Flax states: #{len(state.flat_state())}")

    def convert_to_jax(tensor):
        if tensor.dtype == torch.bfloat16:
            return from_torch_bf16(tensor)
        else:
            return jnp.asarray(tensor.numpy())

    for key in sorted(state_dict.keys()):
        ptr = state
        node = graph
        torch_cnt += 1
        try:
            for loc in key.split(".")[:-1]:
                if loc.isnumeric():
                    if "layers" in ptr:
                        ptr = ptr["layers"]
                        node = node.subgraphs["layers"]
                    loc = int(loc)
                ptr = ptr[loc]
                node = node.subgraphs[loc]
            last = key.split(".")[-1]
            if last not in ptr._mapping.keys():
                ptr_keys = list(ptr._mapping.keys())
                ptr_keys = list(filter(lambda x: x != "bias", ptr_keys))
                if len(ptr_keys) == 1:
                    ptr_key = ptr_keys[0]
                elif last == "weight" and "kernel" in ptr_keys:
                    ptr_key = "kernel"
                else:
                    cnt += 1
                    raise Exception(f"Mismatched: {key}: {ptr_keys}")
                val = ptr[ptr_key].value
            else:
                if isinstance(ptr[last], jax.Array):
                    val = ptr[last]
                else:
                    val = ptr[last].value
                ptr_key = last
                assert state_dict[key].shape == val.shape, f"{key} mismatched"

            if isinstance(ptr[ptr_key], jax.Array):
                assert state_dict[key].shape == val.shape, f"Array: [{node.type}]mismatched {state_dict[key].shape} {val.shape}"
                kernel = convert_to_jax(state_dict[key])
                val_cnt += 1
                continue
            elif ptr_key == "bias":
                assert state_dict[key].shape == val.shape, f"Bias: [{node.type}]mismatched {state_dict[key].shape} {val.shape}"
                kernel = nnx.Param(convert_to_jax(state_dict[key])).to_state()
            else:
                if 'kernel_size' in node.attributes:
                    # conv kernels: torch stores OIHW, flax expects HWIO
                    kernel = convert_to_jax(state_dict[key])
                    if len(kernel.shape) == 3:
                        kernel = jnp.transpose(kernel, (2, 1, 0))
                    elif len(kernel.shape) == 4:
                        kernel = jnp.transpose(kernel, (2, 3, 1, 0))
                    elif len(kernel.shape) == 5:
                        kernel = jnp.transpose(kernel, (2, 3, 4, 1, 0))
                elif 'dot_general' in node.attributes:
                    # linear weights: torch stores (out, in), flax expects (in, out)
                    kernel = convert_to_jax(state_dict[key])
                    kernel = jnp.transpose(kernel, (1, 0))
                else:
                    kernel = convert_to_jax(state_dict[key])
                assert val.shape == kernel.shape, f"[{node.type}]mismatched {val.shape} {kernel.shape}"
                kernel = nnx.Param(kernel).to_state()
            ptr._mapping[ptr_key] = kernel
            flax_cnt += 1
        except Exception as e:
            print(e, f"{key}")
    print(cnt, torch_cnt, flax_cnt, val_cnt)
    return state


def load_state_dict(model, state_dict):
    graph, state = nnx.split(model)
    state = load_from_torch(graph, state, state_dict)
    nnx.update(model, state)
    return model


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def patch_dtype(model, dtype, patch_param=False):
    for path, module in model.iter_modules():
        if hasattr(module, "dtype") and (module.dtype is None or jnp.issubdtype(module.dtype, jnp.floating)):
            module.dtype = dtype
        if patch_param:
            if hasattr(module, "param_dtype") and jnp.issubdtype(module.param_dtype, jnp.floating):
                module.param_dtype = dtype
    if not patch_param:
        return model
    for path, parent in nnx.iter_graph(model):
        if isinstance(parent, nnx.Module):
            for name, value in vars(parent).items():
                if isinstance(value, nnx.Variable) and value.value is None:
                    pass
                elif isinstance(value, nnx.Variable):
                    if jnp.issubdtype(value.value.dtype, jnp.floating):
                        value.value = value.value.astype(dtype)
                elif isinstance(value, jax.Array):
                    if jnp.issubdtype(value.dtype, jnp.floating):
                        parent.__setattr__(name, value.astype(dtype))
    return model


def load_flow_model(name: str, device: str = "none", hf_download: bool = True):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    # with torch.device("meta" if ckpt_path is not None else device):
    model = Flux(configs[name].params, dtype=jnp.bfloat16, rngs=nnx.Rngs(0))
    model = patch_dtype(model, jnp.bfloat16)
    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_sft(ckpt_path, device="cpu")
        model = load_state_dict(model, sd)
    return model


def load_t5(device: str = "none", max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("lnyan/t5-v1_1-xxl-encoder", max_length=max_length, dtype=jnp.bfloat16)


def load_clip(device: str = "none") -> HFEmbedder:
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=jnp.bfloat16)


def load_ae(name: str, device: str = "none", hf_download: bool = True) -> AutoEncoder:
    ckpt_path = configs[name].ae_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)

    # Loading the autoencoder
    print("Init AE")
    ae = AutoEncoder(configs[name].ae_params, dtype=jnp.bfloat16, rngs=nnx.Rngs(0))
    ae = patch_dtype(ae, jnp.bfloat16)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device="cpu")
        ae = load_state_dict(ae, sd)
    return ae


class WatermarkEmbedder:
    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: Tensor) -> Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, H, W, RGB) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        image = 0.5 * image + 0.5
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        # image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
        image_np = np.array(rearrange((255 * image), "n b h w c -> (n b) h w c"))[:, :, :, ::-1]

        # jax (n, b, h, w, c) in [0, 1] -> numpy (n*b, h, w, c) in [0, 255]
        # the watermarking library expects input in cv2 BGR format
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        image = jnp.asarray(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b h w c", n=n))
        # image = torch.clamp(image / 255, min=0.0, max=1.0)
        image = jnp.clip(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        image = 2 * image - 1
        return image


# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
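Putting the loaders together, as app.py does in get_models (a sketch; the first run downloads many gigabytes of checkpoints from Hugging Face):

from flux.util import load_ae, load_clip, load_flow_model, load_t5

t5 = load_t5(max_length=256)             # 256 tokens is the schnell setting in app.py
clip = load_clip()
model = load_flow_model("flux-schnell")
ae = load_ae("flux-schnell")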
flux/wrapper.py
ADDED
@@ -0,0 +1,139 @@
# Copyright 2024 Lnyan (https://github.com/lkwq007). All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import Array as Tensor
import flax
from flax import nnx
import flax.linen


def fake_init(key, feature_shape, param_dtype):
    # return a shape/dtype placeholder instead of allocating parameters;
    # real weights are filled in later by load_state_dict in flux/util.py
    return jax.ShapeDtypeStruct(feature_shape, param_dtype)


def wrap_LayerNorm(dim, *, eps=1e-5, elementwise_affine=True, bias=True, rngs: nnx.Rngs):
    return nnx.LayerNorm(dim, epsilon=eps, use_bias=elementwise_affine and bias, use_scale=elementwise_affine, bias_init=fake_init, scale_init=fake_init, rngs=rngs)


def wrap_Linear(dim, inner_dim, *, bias=True, rngs: nnx.Rngs):
    return nnx.Linear(dim, inner_dim, use_bias=bias, kernel_init=fake_init, bias_init=fake_init, rngs=rngs)


def wrap_GroupNorm(num_groups, num_channels, *, eps=1e-5, affine=True, rngs: nnx.Rngs):
    return nnx.GroupNorm(num_channels, num_groups=num_groups, epsilon=eps, use_bias=affine, use_scale=affine, bias_init=fake_init, scale_init=fake_init, rngs=rngs)


def wrap_Conv(in_channels, out_channels, kernel_size, *, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', rngs: nnx.Rngs, conv_dim: int):
    if isinstance(kernel_size, int):
        kernel_tuple = (kernel_size,) * conv_dim
    else:
        assert len(kernel_size) == conv_dim
        kernel_tuple = kernel_size
    return nnx.Conv(in_channels, out_channels, kernel_tuple, strides=stride, padding=padding, use_bias=bias, kernel_init=fake_init, bias_init=fake_init, rngs=rngs)
    # return nnx.Conv(in_channels, out_channels, kernel_tuple, stride=stride, padding=padding, dilation=dilation, feature_group_count=groups, use_bias=bias, rngs=rngs)


class nn_GELU(nnx.Module):
    def __init__(self, approximate="none") -> None:
        self.approximate = approximate == "tanh"

    def __call__(self, x):
        return nnx.gelu(x, approximate=self.approximate)


class nn_SiLU(nnx.Module):
    def __init__(self) -> None:
        pass

    def __call__(self, x):
        return nnx.silu(x)


class nn_AvgPool(nnx.Module):
    def __init__(self, window_shape, strides=None, padding="VALID") -> None:
        self.window_shape = window_shape
        self.strides = strides
        self.padding = padding

    def __call__(self, x):
        return flax.linen.avg_pool(x, window_shape=self.window_shape, strides=self.strides, padding=self.padding)


# a wrapper class exposing a torch.nn-like surface over Flax NNX
class TorchWrapper:
    def __init__(self, rngs: nnx.Rngs, dtype=jnp.float32):
        self.rngs = rngs
        self.dtype = dtype

    def declare_with_rng(self, *args):
        ret = list(map(lambda f: partial(f, dtype=self.dtype, rngs=self.rngs), args))
        return ret if len(ret) > 1 else ret[0]

    def conv_nd(self, dims, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=dims)

    def avg_pool(self, *args, **kwargs):
        return nn_AvgPool(*args, **kwargs)

    def linear(self, *args, **kwargs):
        return self.Linear(*args, **kwargs)

    def SiLU(self):
        return nn_SiLU()

    def GELU(self, approximate="none"):
        return nn_GELU(approximate)

    def Identity(self):
        return lambda x: x

    def LayerNorm(self, dim, eps=1e-5, elementwise_affine=True, bias=True):
        return wrap_LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, bias=bias, rngs=self.rngs)

    def GroupNorm(self, *args, **kwargs):
        return wrap_GroupNorm(*args, **kwargs, rngs=self.rngs)

    def Linear(self, *args, **kwargs):
        return wrap_Linear(*args, **kwargs, rngs=self.rngs)

    def Parameter(self, value):
        return nnx.Param(value)

    def Dropout(self, p):
        return nnx.Dropout(rate=p, rngs=self.rngs)

    def Sequential(self, *args):
        return nnx.Sequential(*args)

    def Conv1d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=1)

    def Conv2d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=2)

    def Conv3d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=3)

    def ModuleList(self, lst=None):
        if lst is None:
            return []
        return list(lst)

    def Module(self, *args, **kwargs):
        return nnx.Dict()
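A small illustration of the wrapper's meta-initialization trick: layers are declared with torch-style arguments, but fake_init leaves each parameter as a jax.ShapeDtypeStruct placeholder until load_state_dict in flux/util.py swaps in the converted torch weights. A sketch:

import jax.numpy as jnp
from flax import nnx
from flux.wrapper import TorchWrapper

nn = TorchWrapper(rngs=nnx.Rngs(0), dtype=jnp.float32)
lin = nn.Linear(16, 32, bias=True)    # an nnx.Linear under the hood
print(lin.kernel.value)               # ShapeDtypeStruct((16, 32), float32) -- no memory allocated yet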
requirements.txt
ADDED
@@ -0,0 +1,14 @@
jax[cuda12]
flax==0.9.0
flash_attn_jax
torch
torchvision
opencv-python-headless
einops
huggingface_hub
transformers
tokenizers
sentencepiece
fire
invisible-watermark
ml-dtypes