lnyan committed on
Commit
d4607d7
1 Parent(s): f9616ba
app.py ADDED
@@ -0,0 +1,269 @@
1
+ import os
2
+ import time
3
+ from io import BytesIO
4
+ import uuid
5
+
6
+ import torch
7
+ import gradio as gr
8
+ import spaces
9
+ import numpy as np
10
+ from einops import rearrange
11
+ from PIL import Image, ExifTags
12
+
13
+ from dataclasses import dataclass
14
+
15
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack, prepare_tokens
16
+ from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5
17
+
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from flax import nnx
22
+ from jax import Array as Tensor
23
+ from einops import repeat
24
+
25
+ @dataclass
26
+ class SamplingOptions:
27
+ prompt: str
28
+ width: int
29
+ height: int
30
+ num_steps: int
31
+ guidance: float
32
+ seed: int | None
33
+
34
+ NSFW_THRESHOLD = 0.85
35
+
36
+ @spaces.GPU
37
+ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
38
+ t5 = load_t5(device, max_length=256 if is_schnell else 512)
39
+ clip = load_clip(device)
40
+ model = load_flow_model(name, device="cpu" if offload else device)
41
+ ae = load_ae(name, device="cpu" if offload else device)
42
+ # nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
43
+ # return model, ae, t5, clip, nsfw_classifier
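+ # split each nnx module into (graphdef, state) so it can cross jax.jit boundaries and be re-merged inside the jitted functions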
44
+ return nnx.split(model), nnx.split(ae), nnx.split(t5), t5.tokenizer, nnx.split(clip), clip.tokenizer, None
45
+
46
+ @jax.jit
47
+ def encode(ae,x):
48
+ ae=nnx.merge(*ae)
49
+ return ae.encode(x)
50
+
51
+ def _generate(model, ae, t5, clip, x, t5_tokens, clip_tokens, num_steps, guidance,
52
+ #init_image=None,
53
+ #image2image_strength=0.0,
54
+ shift=True):
55
+ b,h,w,c=x.shape
56
+ model=nnx.merge(*model)
57
+ ae=nnx.merge(*ae)
58
+ t5=nnx.merge(*t5)
59
+ clip=nnx.merge(*clip)
60
+ timesteps = get_schedule(
61
+ num_steps,
62
+ x.shape[-1] * x.shape[-2] // 4,
63
+ shift=shift,
64
+ )
65
+ # if init_image is not None:
66
+ # t_idx = int((1 - image2image_strength) * num_steps)
67
+ # t = timesteps[t_idx]
68
+ # timesteps = timesteps[t_idx:]
69
+ # x = t * x + (1.0 - t) * init_image.astype(x.dtype)
70
+ inp = prepare(t5, clip, x, t5_tokens, clip_tokens)
71
+ x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
72
+ x = unpack(x.astype(jnp.float32), h*8, w*8)
73
+ x = ae.decode(x)
74
+ return x
75
+
76
+ generate=jax.jit(_generate, static_argnames=("num_steps","shift"))
77
+
78
+
79
+ def prepare_tokens(t5_tokenizer, clip_tokenizer, prompt: str | list[str]) -> tuple[Tensor, Tensor]:
80
+ if isinstance(prompt, str):
81
+ prompt = [prompt]
82
+ t5_tokens = t5_tokenizer(
83
+ prompt,
84
+ truncation=True,
85
+ max_length=512,
86
+ return_length=False,
87
+ return_overflowing_tokens=False,
88
+ padding="max_length",
89
+ return_tensors="jax",
90
+ )["input_ids"]
91
+ clip_tokens = clip_tokenizer(
92
+ prompt,
93
+ truncation=True,
94
+ max_length=77,
95
+ return_length=False,
96
+ return_overflowing_tokens=False,
97
+ padding="max_length",
98
+ return_tensors="jax",
99
+ )["input_ids"]
100
+ return t5_tokens, clip_tokens
101
+
102
+
103
+ class FluxGenerator:
104
+ def __init__(self, model_name: str, device: str, offload: bool):
105
+ self.device = None
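+ # JAX handles device placement itself, so the torch-style device argument is effectively ignored here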
106
+ self.offload = offload
107
+ self.model_name = model_name
108
+ self.is_schnell = model_name == "flux-schnell"
109
+ self.model, self.ae, self.t5, self.t5_tokenizer, self.clip, self.clip_tokenizer, self.nsfw_classifier = get_models(
110
+ model_name,
111
+ device=self.device,
112
+ offload=self.offload,
113
+ is_schnell=self.is_schnell,
114
+ )
115
+ self.key = jax.random.key(0)
116
+
117
+ @spaces.GPU(duration=180)
118
+ def generate_image(
119
+ self,
120
+ img_size,
121
+ num_steps,
122
+ guidance,
123
+ seed,
124
+ prompt,
125
+ # init_image=None,
126
+ # image2image_strength=0.0,
127
+ add_sampling_metadata=True,
128
+ ):
129
+ seed = int(seed)
130
+ if seed == -1:
131
+ seed = None
132
+ if img_size == "1,024x1,024":
133
+ width, height = 1024, 1024
134
+ else:
135
+ width, height = 512, 512
136
+
137
+ opts = SamplingOptions(
138
+ prompt=prompt,
139
+ width=width,
140
+ height=height,
141
+ num_steps=num_steps,
142
+ guidance=guidance,
143
+ seed=seed,
144
+ )
145
+
146
+ if opts.seed is None:
147
+ # opts.seed = torch.Generator(device="cpu").seed()
148
+ key,self.key=jax.random.split(self.key,2)
149
+ opts.seed=jax.random.randint(key,(),0,2**30)
150
+ print(f"Generating '{opts.prompt}' with seed {opts.seed}")
151
+ t0 = time.perf_counter()
152
+
153
+ # if init_image is not None:
154
+ # if isinstance(init_image, np.ndarray):
155
+ # init_image = jnp.asarray(init_image).astype(jnp.float32) / 255.0
156
+ # init_image = init_image[None]
157
+ # # init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
158
+ # init_image = jax.image.resize(init_image, (opts.height, opts.width), method="lanczos5")
159
+ # # if self.offload:
160
+ # # self.ae.encoder.to(self.device)
161
+ # # init_image = self.ae.encode(init_image)
162
+ # init_image = encode(self.ae, init_image)
163
+
164
+ # prepare input
165
+ t5_tokens, clip_tokens = prepare_tokens(self.t5_tokenizer, self.clip_tokenizer, prompt=opts.prompt)
166
+ x = get_noise(
167
+ 1,
168
+ opts.height,
169
+ opts.width,
170
+ device=None,
171
+ dtype=jnp.bfloat16,
172
+ seed=opts.seed,
173
+ )
174
+
175
+ x = generate(self.model, self.ae, self.t5, self.clip, x, t5_tokens, clip_tokens, opts.num_steps, opts.guidance, shift=(not self.is_schnell))
176
+
177
+ t1 = time.perf_counter()
178
+ # print(f"Done in {t1 - t0:.1f}s.")
179
+ runtime = t1 - t0
180
+ # print(f"Done in {t1 - t0:.1f}s.")
181
+ # bring into PIL format
182
+ x= jnp.clip(x, -1, 1)
183
+ # x = embed_watermark(x.astype(jnp.float32))
184
+ # x = rearrange(x[0], "c h w -> h w c")
185
+ img = Image.fromarray(np.asarray((127.5 * (x[0] + 1.0))).astype(np.uint8))
186
+ # img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
187
+ # nsfw_score = [x["score"] for x in self.nsfw_classifier(img) if x["label"] == "nsfw"][0]
188
+
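+ # the NSFW classifier above is disabled, so this branch is always taken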
189
+ if True:
190
+ filename = f"output/gradio/{uuid.uuid4()}.jpg"
191
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
192
+ exif_data = Image.Exif()
193
+ # if init_image is None:
194
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
195
+ # else:
196
+ # exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
197
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
198
+ exif_data[ExifTags.Base.Model] = self.model_name
199
+ if add_sampling_metadata:
200
+ exif_data[ExifTags.Base.ImageDescription] = prompt
201
+
202
+ img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
203
+
204
+ return img, runtime, str(opts.seed), filename, None
205
+ else:
206
+ return None, str(opts.seed), None, "Your generated image may contain NSFW content."
207
+
208
+ @spaces.GPU(duration=300)
209
+ def create_demo(model_name: str, device: str = "cuda", offload: bool = False):
210
+ generator = FluxGenerator(model_name, device, offload)
211
+ is_schnell = model_name == "flux-schnell"
212
+
213
+ with open("./assets/banner.html") as f:
214
+ banner = f.read()
215
+ with gr.Blocks() as demo:
216
+ with gr.Column(elem_id="app-container"):
217
+ gr.HTML(f"""<iframe scrolling="no" style="width: 100%; height: 125px; border: 0" srcdoc='{banner}'>""")
218
+ gr.Markdown(f"""🚀 [Flux-Flax](https://github.com/lkwq007/flux-flax) is a JAX implementation of Flux models. 1-step time statistics for `FLUX.1-schnell`: `0.4s` for 1024x1024, `0.1s` for 512x512; 2-step: `0.6s` for 1024x1024, `0.2s` for 512x512; 4-step: `2.4s` for 1024x1024, `0.8s` for 512x512.
219
+ """)
220
+
221
+ with gr.Row():
222
+ with gr.Column(scale=3):
223
+ output_image = gr.Image(label="Generated Image")
224
+ warning_text = gr.Textbox(label="Warning", visible=False)
225
+ download_btn = gr.File(label="Download full-resolution")
226
+ gr.Markdown("""
227
+ 💡 Note: More resolutions are supported, but this demo is limited to 1024x1024 and 512x512 to avoid JIT recompilation (which takes about 130s). Flux-Flax also supports `FLUX.1-dev`; 50-step time statistics: `18s` for 1024x1024, `6s` for 512x512""")
228
+ with gr.Column(scale=1):
229
+ prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture")
230
+ generate_btn = gr.Button("Generate")
231
+ with gr.Row():
232
+ seed_output = gr.Number(label="Used Seed")
233
+ runtime = gr.Number(label="Inference Time", precision=3)
234
+ with gr.Row():
235
+ seed = gr.Textbox(-1, label="Seed (-1 for random)")
236
+ img_size = gr.Radio(["1,024x1,024", "512x512"], label="Image Resolution", value="1,024x1,024")
237
+ num_steps = gr.Slider(1, 4, 1, step=1, label="Number of steps")
238
+ add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)
239
+ guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell, visible=False)
240
+
241
+
242
+
243
+
244
+ # def update_img2img(do_img2img):
245
+ # return {
246
+ # init_image: gr.update(visible=do_img2img),
247
+ # image2image_strength: gr.update(visible=do_img2img),
248
+ # }
249
+
250
+ # do_img2img.change(update_img2img, do_img2img, [init_image, image2image_strength])
251
+ generate_btn.click(
252
+ fn=generator.generate_image,
253
+ inputs=[img_size, num_steps, guidance, seed, prompt, add_sampling_metadata],
254
+ outputs=[output_image, runtime, seed_output, download_btn, warning_text],
255
+ )
256
+
257
+ return demo
258
+
259
+ # if __name__ == "__main__":
260
+ # import argparse
261
+ # parser = argparse.ArgumentParser(description="Flux")
262
+ # parser.add_argument("--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name")
263
+ # parser.add_argument("--device", type=str, default="cpu", help="Device to use")
264
+ # parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
265
+ # parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
266
+ # args = parser.parse_args()
267
+
268
+ demo = create_demo("flux-schnell", None, False)
269
+ demo.launch()
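
A minimal usage sketch of the generator above (assumptions: the flux.* modules and model weights are available, and the spaces.GPU decorators are usable or behave as no-ops outside a Hugging Face Space):

    generator = FluxGenerator("flux-schnell", device=None, offload=False)
    img, runtime, used_seed, path, warning = generator.generate_image(
        img_size="512x512",   # anything other than "1,024x1,024" falls back to 512x512
        num_steps=4,
        guidance=3.5,
        seed=-1,              # -1 draws a random seed, mirroring the UI default
        prompt="a misty forest at dawn",
    )
    print(f"seed={used_seed}, took {runtime:.2f}s, saved to {path}")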
flux/__init__.py ADDED
File without changes
flux/math.py ADDED
@@ -0,0 +1,94 @@
1
+ # import torch
2
+ import math
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from einops import rearrange
6
+ from flax import nnx
7
+
8
+ Tensor=jax.Array
9
+
10
+ def check_tpu():
11
+ return any('TPU' in d.device_kind for d in jax.devices())
12
+
13
+ # from torch import Tensor
14
+ if check_tpu():
15
+ from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
16
+ # q, # [batch_size, num_heads, q_seq_len, d_model]
17
+ # k, # [batch_size, num_heads, kv_seq_len, d_model]
18
+ # v, # [batch_size, num_heads, kv_seq_len, d_model]
19
+ def flash_mha(q, k, v):
20
+ return flash_attention(q, k, v, sm_scale=1/math.sqrt(q.shape[-1]))
21
+ else:
22
+ from jax.experimental.pallas.ops.gpu.attention import mha, mha_reference
23
+ def pallas_mha(q, k, v):
24
+ # B L H D
25
+ # return mha_reference(q,k,v,segment_ids=None,sm_scale=1/math.sqrt(q.shape[-1]))
26
+ q_len=q.shape[1]
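+ # round the sequence length up to the next multiple of 128 (the Pallas kernel's block size);
+ # the padded positions get segment id 1 so they stay separate from the real tokens (segment id 0)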
27
+ diff=(-q_len)&127
28
+ segment_ids=jnp.zeros((q.shape[0],q.shape[1]),dtype=jnp.int32)
29
+ segment_ids=jnp.pad(segment_ids,((0,0),(0,diff)),mode="constant",constant_values=1)
30
+ # q,k,v=map(lambda x: jnp.pad(x,((0,0),(0,diff),(0,0),(0,0)),mode="constant", constant_values=0),(q,k,v))
31
+ return mha(q,k,v,segment_ids=segment_ids,sm_scale=1/math.sqrt(q.shape[-1]))#[:,:q_len]
32
+ # mha: batch_size, seq_len, num_heads, head_dim = q.shape
33
+ from functools import partial
34
+ from flux.modules.attention_flax import jax_memory_efficient_attention
35
+ try:
36
+ from flash_attn_jax import flash_mha
37
+ except Exception:
38
+ flash_mha = pallas_mha
39
+ # flash_mha = nnx.dot_product_attention
40
+
41
+
42
+ def dot_product_attention(q, k, v, sm_scale=1.0):
43
+ q,k,v=map(lambda x: rearrange(x, "b h n d -> b n h d"), (q,k,v))
44
+ # ret = pallas_mha(q,k,v)
45
+ ret = nnx.dot_product_attention(q,k,v)
46
+ # if q.shape[-3] % 64 == 0:
47
+ # query_chunk_size = int(q.shape[-3] / 64)
48
+ # elif q.shape[-3] % 16 == 0:
49
+ # query_chunk_size = int(q.shape[-3] / 16)
50
+ # elif q.shape[-3] % 4 == 0:
51
+ # query_chunk_size = int(q.shape[-3] / 4)
52
+ # else:
53
+ # query_chunk_size = int(q.shape[-3])
54
+ # ret=jax_memory_efficient_attention(q, k, v, query_chunk_size=query_chunk_size)
55
+ return rearrange(ret, "b n h d -> b h n d")
56
+
57
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
58
+ q, k = apply_rope(q, k, pe)
59
+ # x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
60
+ # q is B H L D
61
+ q,k,v=map(lambda x: rearrange(x, "B H L D -> B L H D"), (q,k,v))
62
+ # x = nnx.dot_product_attention(q,k,v)
63
+ x = flash_mha(q,k,v)
64
+ # x = pallas_mha(q,k,v)
65
+ # x = mha(q,k,v,None,sm_scale=1/math.sqrt(q.shape[-1]))
66
+ x = rearrange(x, "B L H D -> B L (H D)")
67
+
68
+ # x = rearrange(x, "B H L D -> B L (H D)")
69
+
70
+ return x
71
+
72
+
73
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
74
+ assert dim % 2 == 0
75
+ # scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
76
+ scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
77
+ omega = 1.0 / (theta**scale)
78
+ # out = torch.einsum("...n,d->...nd", pos, omega)
79
+ out = jnp.einsum("...n,d->...nd", pos.astype(jnp.float32), omega)
80
+ # out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
81
+ out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1)
82
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
83
+ # return out.float()
84
+ return out.astype(jnp.float32)
85
+
86
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
87
+ # xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
88
+ # xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
89
+ xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
90
+ xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
91
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
92
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
93
+ # return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
94
+ return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
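
A minimal shape-check sketch for the RoPE helpers above (the sizes are arbitrary; the head dimension must be even):

    import jax.numpy as jnp

    B, H, L, D = 1, 2, 8, 16
    pos = jnp.broadcast_to(jnp.arange(L), (B, L))   # token positions
    pe = rope(pos, D, theta=10_000)[:, None]        # (B, 1, L, D/2, 2, 2); the extra axis broadcasts over heads
    q = jnp.ones((B, H, L, D))
    k = jnp.ones((B, H, L, D))
    q_rot, k_rot = apply_rope(q, k, pe)
    assert q_rot.shape == (B, H, L, D) and k_rot.shape == (B, H, L, D)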
flux/model.py ADDED
@@ -0,0 +1,120 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ import jax.numpy as jnp
5
+ from jax import Array as Tensor
6
+ from flax import nnx
7
+
8
+ from flux.wrapper import TorchWrapper
9
+
10
+ from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
11
+ MLPEmbedder, SingleStreamBlock,
12
+ timestep_embedding)
13
+
14
+
15
+ @dataclass
16
+ class FluxParams:
17
+ in_channels: int
18
+ vec_in_dim: int
19
+ context_in_dim: int
20
+ hidden_size: int
21
+ mlp_ratio: float
22
+ num_heads: int
23
+ depth: int
24
+ depth_single_blocks: int
25
+ axes_dim: list[int]
26
+ theta: int
27
+ qkv_bias: bool
28
+ guidance_embed: bool
29
+
30
+
31
+ DoubleStreamBlock_class, EmbedND_class, LastLayer_class, MLPEmbedder_class, SingleStreamBlock_class = DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock
32
+
33
+ class Flux(nnx.Module):
34
+ """
35
+ Transformer model for flow matching on sequences.
36
+ """
37
+
38
+ def __init__(self, params: FluxParams, dtype: jnp.dtype = jnp.float32, rngs: nnx.Rngs = None):
39
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
40
+ DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock = nn.declare_with_rng(DoubleStreamBlock_class, EmbedND_class, LastLayer_class, MLPEmbedder_class, SingleStreamBlock_class)
41
+ self.params = params
42
+ self.in_channels = params.in_channels
43
+ self.out_channels = self.in_channels
44
+ if params.hidden_size % params.num_heads != 0:
45
+ raise ValueError(
46
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
47
+ )
48
+ pe_dim = params.hidden_size // params.num_heads
49
+ if sum(params.axes_dim) != pe_dim:
50
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
51
+ self.hidden_size = params.hidden_size
52
+ self.num_heads = params.num_heads
53
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
54
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
55
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
56
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
57
+ self.guidance_in = (
58
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
59
+ )
60
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
61
+
62
+ self.double_blocks = nn.ModuleList(
63
+ [
64
+ DoubleStreamBlock(
65
+ self.hidden_size,
66
+ self.num_heads,
67
+ mlp_ratio=params.mlp_ratio,
68
+ qkv_bias=params.qkv_bias,
69
+ )
70
+ for _ in range(params.depth)
71
+ ]
72
+ )
73
+
74
+ self.single_blocks = nn.ModuleList(
75
+ [
76
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
77
+ for _ in range(params.depth_single_blocks)
78
+ ]
79
+ )
80
+
81
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
82
+
83
+ def __call__(
84
+ self,
85
+ img: Tensor,
86
+ img_ids: Tensor,
87
+ txt: Tensor,
88
+ txt_ids: Tensor,
89
+ timesteps: Tensor,
90
+ y: Tensor,
91
+ guidance: Tensor | None = None,
92
+ ) -> Tensor:
93
+ if img.ndim != 3 or txt.ndim != 3:
94
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
95
+
96
+ # running on sequences img
97
+ img = self.img_in(img)
98
+ vec = self.time_in(timestep_embedding(timesteps, 256))
99
+ if self.params.guidance_embed:
100
+ if guidance is None:
101
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
102
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
103
+ vec = vec + self.vector_in(y)
104
+ txt = self.txt_in(txt)
105
+
106
+ # ids = torch.cat((txt_ids, img_ids), dim=1)
107
+ ids = jnp.concatenate((txt_ids, img_ids), axis=1)
108
+ pe = self.pe_embedder(ids)
109
+
110
+ for block in self.double_blocks:
111
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
112
+
113
+ # img = torch.cat((txt, img), 1)
114
+ img = jnp.concatenate((txt, img), axis=1)
115
+ for block in self.single_blocks:
116
+ img = block(img, vec=vec, pe=pe)
117
+ img = img[:, txt.shape[1] :, ...]
118
+
119
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
120
+ return img
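
A minimal construction sketch for the model above; the parameter values are made up for illustration (the only constraint is that axes_dim sums to hidden_size // num_heads), and the real configurations come from flux.util.configs:

    from flax import nnx
    import jax.numpy as jnp

    params = FluxParams(
        in_channels=64, vec_in_dim=768, context_in_dim=4096,
        hidden_size=1024, mlp_ratio=4.0, num_heads=8,
        depth=2, depth_single_blocks=2,
        axes_dim=[16, 56, 56],   # sums to 1024 // 8 = 128
        theta=10_000, qkv_bias=True, guidance_embed=False,
    )
    model = Flux(params, dtype=jnp.bfloat16, rngs=nnx.Rngs(0))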
flux/modules/attention_flax.py ADDED
@@ -0,0 +1,494 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+
22
+
23
+ def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
24
+ """Multi-head dot product attention with a limited number of queries."""
25
+ num_kv, num_heads, k_features = key.shape[-3:]
26
+ v_features = value.shape[-1]
27
+ key_chunk_size = min(key_chunk_size, num_kv)
28
+ query = query / jnp.sqrt(k_features)
29
+
30
+ @functools.partial(jax.checkpoint, prevent_cse=False)
31
+ def summarize_chunk(query, key, value):
32
+ attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
33
+
34
+ max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
35
+ max_score = jax.lax.stop_gradient(max_score)
36
+ exp_weights = jnp.exp(attn_weights - max_score)
37
+
38
+ exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
39
+ max_score = jnp.einsum("...qhk->...qh", max_score)
40
+
41
+ return (exp_values, exp_weights.sum(axis=-1), max_score)
42
+
43
+ def chunk_scanner(chunk_idx):
44
+ # julienne key array
45
+ key_chunk = jax.lax.dynamic_slice(
46
+ operand=key,
47
+ start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
48
+ slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
49
+ )
50
+
51
+ # julienne value array
52
+ value_chunk = jax.lax.dynamic_slice(
53
+ operand=value,
54
+ start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
55
+ slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
56
+ )
57
+
58
+ return summarize_chunk(query, key_chunk, value_chunk)
59
+
60
+ chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
61
+
62
+ global_max = jnp.max(chunk_max, axis=0, keepdims=True)
63
+ max_diffs = jnp.exp(chunk_max - global_max)
64
+
65
+ chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
66
+ chunk_weights *= max_diffs
67
+
68
+ all_values = chunk_values.sum(axis=0)
69
+ all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
70
+
71
+ return all_values / all_weights
72
+
73
+
74
+ def jax_memory_efficient_attention(
75
+ query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
76
+ ):
77
+ r"""
78
+ Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
79
+ https://github.com/AminRezaei0x443/memory-efficient-attention
80
+
81
+ Args:
82
+ query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
83
+ key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
84
+ value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
85
+ precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
86
+ numerical precision for computation
87
+ query_chunk_size (`int`, *optional*, defaults to 1024):
88
+ chunk size to divide query array value must divide query_length equally without remainder
89
+ key_chunk_size (`int`, *optional*, defaults to 4096):
90
+ chunk size to divide key and value array value must divide key_value_length equally without remainder
91
+
92
+ Returns:
93
+ (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
94
+ """
95
+ num_q, num_heads, q_features = query.shape[-3:]
96
+
97
+ def chunk_scanner(chunk_idx, _):
98
+ # julienne query array
99
+ query_chunk = jax.lax.dynamic_slice(
100
+ operand=query,
101
+ start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
102
+ slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
103
+ )
104
+
105
+ return (
106
+ chunk_idx + query_chunk_size, # unused ignore it
107
+ _query_chunk_attention(
108
+ query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
109
+ ),
110
+ )
111
+
112
+ _, res = jax.lax.scan(
113
+ f=chunk_scanner,
114
+ init=0,
115
+ xs=None,
116
+ length=math.ceil(num_q / query_chunk_size), # start counter # stop counter
117
+ )
118
+
119
+ return jnp.concatenate(res, axis=-3) # fuse the chunked result back
120
+
121
+
122
+ class FlaxAttention(nn.Module):
123
+ r"""
124
+ A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
125
+
126
+ Parameters:
127
+ query_dim (:obj:`int`):
128
+ Input hidden states dimension
129
+ heads (:obj:`int`, *optional*, defaults to 8):
130
+ Number of heads
131
+ dim_head (:obj:`int`, *optional*, defaults to 64):
132
+ Hidden states dimension inside each head
133
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
134
+ Dropout rate
135
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
136
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
137
+ split_head_dim (`bool`, *optional*, defaults to `False`):
138
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
139
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
140
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
141
+ Parameters `dtype`
142
+
143
+ """
144
+
145
+ query_dim: int
146
+ heads: int = 8
147
+ dim_head: int = 64
148
+ dropout: float = 0.0
149
+ use_memory_efficient_attention: bool = False
150
+ split_head_dim: bool = False
151
+ dtype: jnp.dtype = jnp.float32
152
+
153
+ def setup(self):
154
+ inner_dim = self.dim_head * self.heads
155
+ self.scale = self.dim_head**-0.5
156
+
157
+ # Weights were exported with old names {to_q, to_k, to_v, to_out}
158
+ self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
159
+ self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
160
+ self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
161
+
162
+ self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
163
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
164
+
165
+ def reshape_heads_to_batch_dim(self, tensor):
166
+ batch_size, seq_len, dim = tensor.shape
167
+ head_size = self.heads
168
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
169
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
170
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
171
+ return tensor
172
+
173
+ def reshape_batch_dim_to_heads(self, tensor):
174
+ batch_size, seq_len, dim = tensor.shape
175
+ head_size = self.heads
176
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
177
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
178
+ tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
179
+ return tensor
180
+
181
+ def __call__(self, hidden_states, context=None, deterministic=True):
182
+ context = hidden_states if context is None else context
183
+
184
+ query_proj = self.query(hidden_states)
185
+ key_proj = self.key(context)
186
+ value_proj = self.value(context)
187
+
188
+ if self.split_head_dim:
189
+ b = hidden_states.shape[0]
190
+ query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
191
+ key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
192
+ value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
193
+ else:
194
+ query_states = self.reshape_heads_to_batch_dim(query_proj)
195
+ key_states = self.reshape_heads_to_batch_dim(key_proj)
196
+ value_states = self.reshape_heads_to_batch_dim(value_proj)
197
+
198
+ if self.use_memory_efficient_attention:
199
+ query_states = query_states.transpose(1, 0, 2)
200
+ key_states = key_states.transpose(1, 0, 2)
201
+ value_states = value_states.transpose(1, 0, 2)
202
+
203
+ # this if statement creates a chunk size for each layer of the unet
204
+ # the chunk size is equal to the query_length dimension of the deepest layer of the unet
205
+
206
+ flatten_latent_dim = query_states.shape[-3]
207
+ if flatten_latent_dim % 64 == 0:
208
+ query_chunk_size = int(flatten_latent_dim / 64)
209
+ elif flatten_latent_dim % 16 == 0:
210
+ query_chunk_size = int(flatten_latent_dim / 16)
211
+ elif flatten_latent_dim % 4 == 0:
212
+ query_chunk_size = int(flatten_latent_dim / 4)
213
+ else:
214
+ query_chunk_size = int(flatten_latent_dim)
215
+
216
+ hidden_states = jax_memory_efficient_attention(
217
+ query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
218
+ )
219
+
220
+ hidden_states = hidden_states.transpose(1, 0, 2)
221
+ else:
222
+ # compute attentions
223
+ if self.split_head_dim:
224
+ attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
225
+ else:
226
+ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
227
+
228
+ attention_scores = attention_scores * self.scale
229
+ attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
230
+
231
+ # attend to values
232
+ if self.split_head_dim:
233
+ hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
234
+ b = hidden_states.shape[0]
235
+ hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
236
+ else:
237
+ hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
238
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
239
+
240
+ hidden_states = self.proj_attn(hidden_states)
241
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
242
+
243
+
244
+ class FlaxBasicTransformerBlock(nn.Module):
245
+ r"""
246
+ A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
247
+ https://arxiv.org/abs/1706.03762
248
+
249
+
250
+ Parameters:
251
+ dim (:obj:`int`):
252
+ Inner hidden states dimension
253
+ n_heads (:obj:`int`):
254
+ Number of heads
255
+ d_head (:obj:`int`):
256
+ Hidden states dimension inside each head
257
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
258
+ Dropout rate
259
+ only_cross_attention (`bool`, defaults to `False`):
260
+ Whether to only apply cross attention.
261
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
262
+ Parameters `dtype`
263
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
264
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
265
+ split_head_dim (`bool`, *optional*, defaults to `False`):
266
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
267
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
268
+ """
269
+
270
+ dim: int
271
+ n_heads: int
272
+ d_head: int
273
+ dropout: float = 0.0
274
+ only_cross_attention: bool = False
275
+ dtype: jnp.dtype = jnp.float32
276
+ use_memory_efficient_attention: bool = False
277
+ split_head_dim: bool = False
278
+
279
+ def setup(self):
280
+ # self attention (or cross_attention if only_cross_attention is True)
281
+ self.attn1 = FlaxAttention(
282
+ self.dim,
283
+ self.n_heads,
284
+ self.d_head,
285
+ self.dropout,
286
+ self.use_memory_efficient_attention,
287
+ self.split_head_dim,
288
+ dtype=self.dtype,
289
+ )
290
+ # cross attention
291
+ self.attn2 = FlaxAttention(
292
+ self.dim,
293
+ self.n_heads,
294
+ self.d_head,
295
+ self.dropout,
296
+ self.use_memory_efficient_attention,
297
+ self.split_head_dim,
298
+ dtype=self.dtype,
299
+ )
300
+ self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
301
+ self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
302
+ self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
303
+ self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
304
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
305
+
306
+ def __call__(self, hidden_states, context, deterministic=True):
307
+ # self attention
308
+ residual = hidden_states
309
+ if self.only_cross_attention:
310
+ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
311
+ else:
312
+ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
313
+ hidden_states = hidden_states + residual
314
+
315
+ # cross attention
316
+ residual = hidden_states
317
+ hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
318
+ hidden_states = hidden_states + residual
319
+
320
+ # feed forward
321
+ residual = hidden_states
322
+ hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
323
+ hidden_states = hidden_states + residual
324
+
325
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
326
+
327
+
328
+ class FlaxTransformer2DModel(nn.Module):
329
+ r"""
330
+ A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
331
+ https://arxiv.org/pdf/1506.02025.pdf
332
+
333
+
334
+ Parameters:
335
+ in_channels (:obj:`int`):
336
+ Input number of channels
337
+ n_heads (:obj:`int`):
338
+ Number of heads
339
+ d_head (:obj:`int`):
340
+ Hidden states dimension inside each head
341
+ depth (:obj:`int`, *optional*, defaults to 1):
342
+ Number of transformers block
343
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
344
+ Dropout rate
345
+ use_linear_projection (`bool`, defaults to `False`): tbd
346
+ only_cross_attention (`bool`, defaults to `False`): tbd
347
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
348
+ Parameters `dtype`
349
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
350
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
351
+ split_head_dim (`bool`, *optional*, defaults to `False`):
352
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
353
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
354
+ """
355
+
356
+ in_channels: int
357
+ n_heads: int
358
+ d_head: int
359
+ depth: int = 1
360
+ dropout: float = 0.0
361
+ use_linear_projection: bool = False
362
+ only_cross_attention: bool = False
363
+ dtype: jnp.dtype = jnp.float32
364
+ use_memory_efficient_attention: bool = False
365
+ split_head_dim: bool = False
366
+
367
+ def setup(self):
368
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
369
+
370
+ inner_dim = self.n_heads * self.d_head
371
+ if self.use_linear_projection:
372
+ self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
373
+ else:
374
+ self.proj_in = nn.Conv(
375
+ inner_dim,
376
+ kernel_size=(1, 1),
377
+ strides=(1, 1),
378
+ padding="VALID",
379
+ dtype=self.dtype,
380
+ )
381
+
382
+ self.transformer_blocks = [
383
+ FlaxBasicTransformerBlock(
384
+ inner_dim,
385
+ self.n_heads,
386
+ self.d_head,
387
+ dropout=self.dropout,
388
+ only_cross_attention=self.only_cross_attention,
389
+ dtype=self.dtype,
390
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
391
+ split_head_dim=self.split_head_dim,
392
+ )
393
+ for _ in range(self.depth)
394
+ ]
395
+
396
+ if self.use_linear_projection:
397
+ self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
398
+ else:
399
+ self.proj_out = nn.Conv(
400
+ inner_dim,
401
+ kernel_size=(1, 1),
402
+ strides=(1, 1),
403
+ padding="VALID",
404
+ dtype=self.dtype,
405
+ )
406
+
407
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
408
+
409
+ def __call__(self, hidden_states, context, deterministic=True):
410
+ batch, height, width, channels = hidden_states.shape
411
+ residual = hidden_states
412
+ hidden_states = self.norm(hidden_states)
413
+ if self.use_linear_projection:
414
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
415
+ hidden_states = self.proj_in(hidden_states)
416
+ else:
417
+ hidden_states = self.proj_in(hidden_states)
418
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
419
+
420
+ for transformer_block in self.transformer_blocks:
421
+ hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
422
+
423
+ if self.use_linear_projection:
424
+ hidden_states = self.proj_out(hidden_states)
425
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
426
+ else:
427
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
428
+ hidden_states = self.proj_out(hidden_states)
429
+
430
+ hidden_states = hidden_states + residual
431
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
432
+
433
+
434
+ class FlaxFeedForward(nn.Module):
435
+ r"""
436
+ Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
437
+ [`FeedForward`] class, with the following simplifications:
438
+ - The activation function is currently hardcoded to a gated linear unit from:
439
+ https://arxiv.org/abs/2002.05202
440
+ - `dim_out` is equal to `dim`.
441
+ - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
442
+
443
+ Parameters:
444
+ dim (:obj:`int`):
445
+ Inner hidden states dimension
446
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
447
+ Dropout rate
448
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
449
+ Parameters `dtype`
450
+ """
451
+
452
+ dim: int
453
+ dropout: float = 0.0
454
+ dtype: jnp.dtype = jnp.float32
455
+
456
+ def setup(self):
457
+ # The second linear layer needs to be called
458
+ # net_2 for now to match the index of the Sequential layer
459
+ self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
460
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
461
+
462
+ def __call__(self, hidden_states, deterministic=True):
463
+ hidden_states = self.net_0(hidden_states, deterministic=deterministic)
464
+ hidden_states = self.net_2(hidden_states)
465
+ return hidden_states
466
+
467
+
468
+ class FlaxGEGLU(nn.Module):
469
+ r"""
470
+ Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
471
+ https://arxiv.org/abs/2002.05202.
472
+
473
+ Parameters:
474
+ dim (:obj:`int`):
475
+ Input hidden states dimension
476
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
477
+ Dropout rate
478
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
479
+ Parameters `dtype`
480
+ """
481
+
482
+ dim: int
483
+ dropout: float = 0.0
484
+ dtype: jnp.dtype = jnp.float32
485
+
486
+ def setup(self):
487
+ inner_dim = self.dim * 4
488
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
489
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
490
+
491
+ def __call__(self, hidden_states, deterministic=True):
492
+ hidden_states = self.proj(hidden_states)
493
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
494
+ return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
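
A minimal sketch of calling the chunked attention helper directly (both chunk sizes must divide the corresponding sequence lengths, as the docstring above notes):

    import jax
    import jax.numpy as jnp

    B, Lq, Lkv, H, D = 2, 1024, 1024, 8, 64
    key = jax.random.PRNGKey(0)
    q = jax.random.normal(key, (B, Lq, H, D))
    k = jax.random.normal(key, (B, Lkv, H, D))
    v = jax.random.normal(key, (B, Lkv, H, D))
    out = jax_memory_efficient_attention(q, k, v, query_chunk_size=256, key_chunk_size=512)
    assert out.shape == (B, Lq, H, D)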
flux/modules/autoencoder.py ADDED
@@ -0,0 +1,343 @@
1
+ from dataclasses import dataclass
2
+
3
+ from einops import rearrange
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from jax import Array as Tensor
8
+ from flax import nnx
9
+
10
+ from flux.wrapper import TorchWrapper
11
+ from flux.math import dot_product_attention
12
+
13
+
14
+ @dataclass
15
+ class AutoEncoderParams:
16
+ resolution: int
17
+ in_channels: int
18
+ ch: int
19
+ out_ch: int
20
+ ch_mult: list[int]
21
+ num_res_blocks: int
22
+ z_channels: int
23
+ scale_factor: float
24
+ shift_factor: float
25
+
26
+
27
+
28
+ swish = nnx.swish
29
+
30
+
31
+ class AttnBlock(nnx.Module):
32
+ def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
33
+ nn = TorchWrapper(rngs, dtype=dtype)
34
+ self.in_channels = in_channels
35
+
36
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
37
+
38
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
39
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
40
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
41
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
42
+
43
+ def attention(self, h_: Tensor) -> Tensor:
44
+ h_ = self.norm(h_)
45
+ q = self.q(h_)
46
+ k = self.k(h_)
47
+ v = self.v(h_)
48
+
49
+ # b, c, h, w = q.shape
50
+ b, h, w, c = q.shape
51
+
52
+ # q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
53
+ # k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
54
+ # v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
55
+ q = rearrange(q, "b h w c -> b 1 (h w) c")
56
+ k = rearrange(k, "b h w c -> b 1 (h w) c")
57
+ v = rearrange(v, "b h w c -> b 1 (h w) c")
58
+ # h_ = nn.functional.scaled_dot_product_attention(q, k, v)
59
+ h_ = dot_product_attention(q, k, v)
60
+
61
+
62
+ # return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
63
+ return rearrange(h_, "b 1 (h w) c -> b h w c", h=h, w=w, c=c, b=b)
64
+
65
+ def __call__(self, x: Tensor) -> Tensor:
66
+ return x + self.proj_out(self.attention(x))
67
+
68
+
69
+ class ResnetBlock(nnx.Module):
70
+ def __init__(self, in_channels: int, out_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
71
+ nn = TorchWrapper(rngs, dtype=dtype)
72
+ self.in_channels = in_channels
73
+ out_channels = in_channels if out_channels is None else out_channels
74
+ self.out_channels = out_channels
75
+
76
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
77
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
78
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
79
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
80
+ if self.in_channels != self.out_channels:
81
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
82
+
83
+ def __call__(self, x):
84
+ h = x
85
+ h = self.norm1(h)
86
+ h = swish(h)
87
+ h = self.conv1(h)
88
+
89
+ h = self.norm2(h)
90
+ h = swish(h)
91
+ h = self.conv2(h)
92
+
93
+ if self.in_channels != self.out_channels:
94
+ x = self.nin_shortcut(x)
95
+
96
+ return x + h
97
+
98
+
99
+ class Downsample(nnx.Module):
100
+ def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
101
+ nn = TorchWrapper(rngs, dtype=dtype)
102
+ # no asymmetric padding in torch conv, must do it ourselves
103
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
104
+
105
+ def __call__(self, x: Tensor):
106
+ # pad = (0, 1, 0, 1)
107
+ # x = nn.functional.pad(x, pad, mode="constant", value=0)
108
+ x = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)), mode="constant")
109
+ x = self.conv(x)
110
+ return x
111
+
112
+
113
+ class Upsample(nnx.Module):
114
+ def __init__(self, in_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
115
+ nn = TorchWrapper(rngs, dtype=dtype)
116
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
117
+
118
+ def __call__(self, x: Tensor):
119
+ # x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
120
+ B, H, W, C = x.shape
121
+ x = jax.image.resize(x, (B, H * 2, W * 2, C), method="nearest")
122
+ x = self.conv(x)
123
+ return x
124
+
125
+ ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class = ResnetBlock, Downsample, Upsample, AttnBlock
126
+
127
+ class Encoder(nnx.Module):
128
+ def __init__(
129
+ self,
130
+ resolution: int,
131
+ in_channels: int,
132
+ ch: int,
133
+ ch_mult: list[int],
134
+ num_res_blocks: int,
135
+ z_channels: int,
136
+ dtype=jnp.float32,
137
+ rngs: nnx.Rngs = None
138
+ ):
139
+ nn = TorchWrapper(rngs, dtype=dtype)
140
+ ResnetBlock, Downsample, Upsample, AttnBlock = nn.declare_with_rng(ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class)
141
+ self.ch = ch
142
+ self.num_resolutions = len(ch_mult)
143
+ self.num_res_blocks = num_res_blocks
144
+ self.resolution = resolution
145
+ self.in_channels = in_channels
146
+ # downsampling
147
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
148
+
149
+ curr_res = resolution
150
+ in_ch_mult = (1,) + tuple(ch_mult)
151
+ self.in_ch_mult = in_ch_mult
152
+ self.down = nn.ModuleList()
153
+ block_in = self.ch
154
+ for i_level in range(self.num_resolutions):
155
+ block = nn.ModuleList()
156
+ attn = nn.ModuleList()
157
+ block_in = ch * in_ch_mult[i_level]
158
+ block_out = ch * ch_mult[i_level]
159
+ for _ in range(self.num_res_blocks):
160
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
161
+ block_in = block_out
162
+ down = nn.Module()
163
+ down.block = block
164
+ down.attn = attn
165
+ if i_level != self.num_resolutions - 1:
166
+ down.downsample = Downsample(block_in)
167
+ curr_res = curr_res // 2
168
+ self.down.append(down)
169
+
170
+ # middle
171
+ self.mid = nn.Module()
172
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
173
+ self.mid.attn_1 = AttnBlock(block_in)
174
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
175
+
176
+ # end
177
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
178
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
179
+
180
+ def __call__(self, x: Tensor) -> Tensor:
181
+ # downsampling
182
+ hs = [self.conv_in(x)]
183
+ for i_level in range(self.num_resolutions):
184
+ for i_block in range(self.num_res_blocks):
185
+ h = self.down[i_level].block[i_block](hs[-1])
186
+ if len(self.down[i_level].attn) > 0:
187
+ h = self.down[i_level].attn[i_block](h)
188
+ hs.append(h)
189
+ if i_level != self.num_resolutions - 1:
190
+ hs.append(self.down[i_level].downsample(hs[-1]))
191
+
192
+ # middle
193
+ h = hs[-1]
194
+ h = self.mid.block_1(h)
195
+ h = self.mid.attn_1(h)
196
+ h = self.mid.block_2(h)
197
+ # end
198
+ h = self.norm_out(h)
199
+ h = swish(h)
200
+ h = self.conv_out(h)
201
+ return h
202
+
203
+
204
+ class Decoder(nnx.Module):
205
+ def __init__(
206
+ self,
207
+ ch: int,
208
+ out_ch: int,
209
+ ch_mult: list[int],
210
+ num_res_blocks: int,
211
+ in_channels: int,
212
+ resolution: int,
213
+ z_channels: int,
214
+ dtype=jnp.float32,
215
+ rngs: nnx.Rngs = None
216
+ ):
217
+ nn = TorchWrapper(rngs, dtype=dtype)
218
+ ResnetBlock, Downsample, Upsample, AttnBlock = nn.declare_with_rng(ResnetBlock_class, Downsample_class, Upsample_class, AttnBlock_class)
219
+ self.ch = ch
220
+ self.num_resolutions = len(ch_mult)
221
+ self.num_res_blocks = num_res_blocks
222
+ self.resolution = resolution
223
+ self.in_channels = in_channels
224
+ self.ffactor = 2 ** (self.num_resolutions - 1)
225
+
226
+ # compute in_ch_mult, block_in and curr_res at lowest res
227
+ block_in = ch * ch_mult[self.num_resolutions - 1]
228
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
229
+ self.z_shape = (1, z_channels, curr_res, curr_res)
230
+
231
+ # z to block_in
232
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
233
+
234
+ # middle
235
+ self.mid = nn.Module()
236
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
237
+ self.mid.attn_1 = AttnBlock(block_in)
238
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
239
+
240
+ # upsampling
241
+ self.up = nn.ModuleList()
242
+ for i_level in reversed(range(self.num_resolutions)):
243
+ block = nn.ModuleList()
244
+ attn = nn.ModuleList()
245
+ block_out = ch * ch_mult[i_level]
246
+ for _ in range(self.num_res_blocks + 1):
247
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
248
+ block_in = block_out
249
+ up = nn.Module()
250
+ up.block = block
251
+ up.attn = attn
252
+ if i_level != 0:
253
+ up.upsample = Upsample(block_in)
254
+ curr_res = curr_res * 2
255
+ self.up.insert(0, up) # prepend to get consistent order
256
+
257
+ # end
258
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
259
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
260
+
261
+ def __call__(self, z: Tensor) -> Tensor:
262
+ # z to block_in
263
+ h = self.conv_in(z)
264
+
265
+ # middle
266
+ h = self.mid.block_1(h)
267
+ h = self.mid.attn_1(h)
268
+ h = self.mid.block_2(h)
269
+
270
+ # upsampling
271
+ for i_level in reversed(range(self.num_resolutions)):
272
+ for i_block in range(self.num_res_blocks + 1):
273
+ h = self.up[i_level].block[i_block](h)
274
+ if len(self.up[i_level].attn) > 0:
275
+ h = self.up[i_level].attn[i_block](h)
276
+ if i_level != 0:
277
+ h = self.up[i_level].upsample(h)
278
+
279
+ # end
280
+ h = self.norm_out(h)
281
+ h = swish(h)
282
+ h = self.conv_out(h)
283
+ return h
284
+
285
+
286
+ class DiagonalGaussian(nnx.Module):
287
+ def __init__(self, sample: bool = True, chunk_dim: int = -1, dtype=jnp.float32, rngs: nnx.Rngs = None):
288
+ self.sample = sample
289
+ self.chunk_dim = chunk_dim
290
+ self.rngs = rngs
291
+ self.dtype = dtype
292
+
293
+ def __call__(self, z: Tensor) -> Tensor:
294
+ # mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
295
+ mean, logvar = jnp.split(z, 2, axis=self.chunk_dim)
296
+ if self.sample:
297
+ # std = torch.exp(0.5 * logvar)
298
+ # return mean + std * torch.randn_like(mean)
299
+ std = jnp.exp(0.5 * logvar)
300
+ return mean + std * jax.random.normal(self.rngs(), mean.shape)
301
+ else:
302
+ return mean
303
+
304
+
305
+ Encoder_class, Decoder_class, DiagonalGaussian_class = Encoder, Decoder, DiagonalGaussian
306
+
307
+ class AutoEncoder(nnx.Module):
308
+ def __init__(self, params: AutoEncoderParams, dtype=jnp.float32, rngs: nnx.Rngs = None):
309
+ nn = TorchWrapper(rngs, dtype=dtype)
310
+ Encoder, Decoder, DiagonalGaussian = nn.declare_with_rng(Encoder_class, Decoder_class, DiagonalGaussian_class)
311
+ self.encoder = Encoder(
312
+ resolution=params.resolution,
313
+ in_channels=params.in_channels,
314
+ ch=params.ch,
315
+ ch_mult=params.ch_mult,
316
+ num_res_blocks=params.num_res_blocks,
317
+ z_channels=params.z_channels,
318
+ )
319
+ self.decoder = Decoder(
320
+ resolution=params.resolution,
321
+ in_channels=params.in_channels,
322
+ ch=params.ch,
323
+ out_ch=params.out_ch,
324
+ ch_mult=params.ch_mult,
325
+ num_res_blocks=params.num_res_blocks,
326
+ z_channels=params.z_channels,
327
+ )
328
+ self.reg = DiagonalGaussian()
329
+
330
+ self.scale_factor = params.scale_factor
331
+ self.shift_factor = params.shift_factor
332
+
333
+ def encode(self, x: Tensor) -> Tensor:
334
+ z = self.reg(self.encoder(x))
335
+ z = self.scale_factor * (z - self.shift_factor)
336
+ return z
337
+
338
+ def decode(self, z: Tensor) -> Tensor:
339
+ z = z / self.scale_factor + self.shift_factor
340
+ return self.decoder(z)
341
+
342
+ def __call__(self, x: Tensor) -> Tensor:
343
+ return self.decode(self.encode(x))
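
A minimal encode/decode sketch, assuming the TorchWrapper and Rngs plumbing behaves as in the rest of this repo; the parameter values are toy ones, much smaller than the real FLUX autoencoder loaded via flux.util.load_ae, and inputs are NHWC (unlike the NCHW torch reference):

    from flax import nnx
    import jax.numpy as jnp

    params = AutoEncoderParams(
        resolution=256, in_channels=3, ch=32, out_ch=3,
        ch_mult=[1, 2], num_res_blocks=1, z_channels=4,
        scale_factor=1.0, shift_factor=0.0,
    )
    ae = AutoEncoder(params, dtype=jnp.float32, rngs=nnx.Rngs(0))
    x = jnp.zeros((1, 64, 64, 3))   # NHWC input
    z = ae.encode(x)                # (1, 32, 32, 4): spatial size divided by 2**(len(ch_mult) - 1)
    x_rec = ae.decode(z)            # (1, 64, 64, 3)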
flux/modules/conditioner.py ADDED
@@ -0,0 +1,73 @@
1
+ from flax import nnx
2
+ import jax.numpy as jnp
3
+ from jax import Array as Tensor
4
+
5
+ from transformers import (FlaxCLIPTextModel, CLIPTokenizer, FlaxT5EncoderModel,
6
+ T5Tokenizer)
7
+
8
+
9
+ class HFEmbedder(nnx.Module):
10
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
11
+ self.is_clip = version.startswith("openai")
12
+ self.max_length = max_length
13
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
14
+ dtype = hf_kwargs.get("dtype", jnp.float32)
15
+ if self.is_clip:
16
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
17
+ # self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
18
+ self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(version, **hf_kwargs)
19
+ else:
20
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
21
+ # self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
22
+ self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(version, **hf_kwargs)
23
+ if dtype==jnp.bfloat16:
24
+ self.hf_module.params = self.hf_module.to_bf16(self.hf_module.params)
25
+
26
+ def tokenize(self, text: list[str]) -> Tensor:
27
+ batch_encoding = self.tokenizer(
28
+ text,
29
+ truncation=True,
30
+ max_length=self.max_length,
31
+ return_length=False,
32
+ return_overflowing_tokens=False,
33
+ padding="max_length",
34
+ return_tensors="jax",
35
+ )
36
+ return batch_encoding["input_ids"]
37
+
38
+ def __call__(self, input_ids: Tensor) -> Tensor:
39
+ # outputs = self.hf_module(
40
+ # input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
41
+ # attention_mask=None,
42
+ # output_hidden_states=False,
43
+ # )
44
+ outputs = self.hf_module(
45
+ input_ids=input_ids,
46
+ attention_mask=None,
47
+ output_hidden_states=False,
48
+ train=False,
49
+ )
50
+ return outputs[self.output_key]
51
+ # def __call__(self, text: list[str]) -> Tensor:
52
+ # batch_encoding = self.tokenizer(
53
+ # text,
54
+ # truncation=True,
55
+ # max_length=self.max_length,
56
+ # return_length=False,
57
+ # return_overflowing_tokens=False,
58
+ # padding="max_length",
59
+ # return_tensors="jax",
60
+ # )
61
+
62
+ # # outputs = self.hf_module(
63
+ # # input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
64
+ # # attention_mask=None,
65
+ # # output_hidden_states=False,
66
+ # # )
67
+ # outputs = self.hf_module(
68
+ # input_ids=batch_encoding["input_ids"],
69
+ # attention_mask=None,
70
+ # output_hidden_states=False,
71
+ # train=False,
72
+ # )
73
+ # return outputs[self.output_key]
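
A minimal usage sketch for the embedder above; the checkpoint name is the CLIP encoder FLUX normally uses but is an assumption here, and in this repo the embedders are constructed by flux.util.load_clip / load_t5 (weights are downloaded from the Hugging Face Hub on first use):

    import jax.numpy as jnp

    clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=jnp.float32)
    ids = clip.tokenize(["a photo of a forest"])   # (1, 77) token ids
    vec = clip(ids)                                # pooled CLIP text embedding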
flux/modules/layers.py ADDED
@@ -0,0 +1,292 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from jax import Array as Tensor
7
+ from flax import nnx
8
+ from einops import rearrange
9
+
10
+ from flux.wrapper import TorchWrapper
11
+ from flux.math import attention, rope
12
+
13
+
14
+ class EmbedND(nnx.Module):
15
+ def __init__(self, dim: int, theta: int, axes_dim: list[int], dtype=jnp.float32, rngs: nnx.Rngs = None):
16
+ self.dim = dim
17
+ self.theta = theta
18
+ self.axes_dim = axes_dim
19
+
20
+ def __call__(self, ids: Tensor) -> Tensor:
21
+ n_axes = ids.shape[-1]
22
+ # emb = torch.cat(
23
+ # [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
24
+ # dim=-3,
25
+ # )
26
+ emb = jnp.concatenate(
27
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
28
+ axis=-3,
29
+ )
30
+
31
+ # return emb.unsqueeze(1)
32
+ return jnp.expand_dims(emb, 1)
33
+
34
+
35
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
36
+ """
37
+ Create sinusoidal timestep embeddings.
38
+ :param t: a 1-D Tensor of N indices, one per batch element.
39
+ These may be fractional.
40
+ :param dim: the dimension of the output.
41
+ :param max_period: controls the minimum frequency of the embeddings.
42
+ :return: an (N, D) Tensor of positional embeddings.
43
+ """
44
+ t = time_factor * t
45
+ half = dim // 2
46
+ # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
47
+ # t.device
48
+ # )
49
+
50
+ freqs = jnp.exp(-math.log(max_period) * jnp.arange(half, dtype=jnp.float32) / half)
51
+
52
+ # args = t[:, None].float() * freqs[None]
53
+ # embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54
+ args = t[:, None] * freqs[None]
55
+ embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
56
+ if dim % 2:
57
+ # embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
58
+ embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
59
+ # if torch.is_floating_point(t):
60
+ # embedding = embedding.to(t)
61
+ # return embedding
62
+ if jnp.issubdtype(t.dtype, jnp.floating):
63
+ embedding = embedding.astype(t.dtype)
64
+ return embedding
65
+
66
+
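Note: timestep_embedding follows the standard sinusoidal recipe, with t pre-scaled by time_factor=1000 so fractional flow-matching timesteps land in a useful frequency range. A quick shape check, as a sketch using this module's imports:

    t = jnp.array([0.0, 0.5, 1.0])        # one (fractional) timestep per batch element
    emb = timestep_embedding(t, dim=256)  # cosines in the first half, sines in the second
    assert emb.shape == (3, 256)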
67
+ class MLPEmbedder(nnx.Module):
68
+ def __init__(self, in_dim: int, hidden_dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
69
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
70
+
71
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
72
+ self.silu = nn.SiLU()
73
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
74
+
75
+ def __call__(self, x: Tensor) -> Tensor:
76
+ return self.out_layer(self.silu(self.in_layer(x)))
77
+
78
+
79
+ class RMSNorm(nnx.Module):
80
+ def __init__(self, dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
81
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
82
+ # self.scale = nn.Parameter(torch.ones(dim))
83
+ self.scale = nn.Parameter(jnp.ones((dim,)))
84
+
85
+
86
+ def __call__(self, x: Tensor):
87
+ x_dtype = x.dtype
88
+ # x = x.float()
89
+ x = x.astype(jnp.float32)
90
+ # rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
91
+ rrms = jax.lax.rsqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6)
92
+ # return (x * rrms).to(dtype=x_dtype) * self.scale
93
+ return (x * rrms).astype(x_dtype) * self.scale
94
+
95
+
96
+ RMSNorm_class = RMSNorm
97
+
98
+ class QKNorm(nnx.Module):
99
+ def __init__(self, dim: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
100
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
101
+ RMSNorm = nn.declare_with_rng(RMSNorm_class)
102
+ self.query_norm = RMSNorm(dim)
103
+ self.key_norm = RMSNorm(dim)
104
+
105
+ def __call__(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
106
+ q = self.query_norm(q)
107
+ k = self.key_norm(k)
108
+ # return q.to(v), k.to(v)
109
+ return q.astype(v.dtype), k.astype(v.dtype)
110
+
111
+
112
+ QKNorm_class = QKNorm
113
+
114
+ class SelfAttention(nnx.Module):
115
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=jnp.float32, rngs: nnx.Rngs = None):
116
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
117
+ QKNorm = nn.declare_with_rng(QKNorm_class)
118
+ self.num_heads = num_heads
119
+ head_dim = dim // num_heads
120
+
121
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
122
+ self.norm = QKNorm(head_dim)
123
+ self.proj = nn.Linear(dim, dim)
124
+
125
+ def __call__(self, x: Tensor, pe: Tensor) -> Tensor:
126
+ qkv = self.qkv(x)
127
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
128
+ q, k = self.norm(q, k, v)
129
+ x = attention(q, k, v, pe=pe)
130
+ x = self.proj(x)
131
+ return x
132
+
133
+
134
+ @dataclass
135
+ class ModulationOut:
136
+ shift: Tensor
137
+ scale: Tensor
138
+ gate: Tensor
139
+
140
+
141
+ class Modulation(nnx.Module):
142
+ def __init__(self, dim: int, double: bool, dtype=jnp.float32, rngs: nnx.Rngs = None):
143
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
144
+ self.is_double = double
145
+ self.multiplier = 6 if double else 3
146
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
147
+
148
+ def __call__(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
149
+ # out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
150
+ out = self.lin(nnx.silu(vec))[:, None, :]
151
+ out = jnp.split(out, self.multiplier, axis=-1)
152
+ return (
153
+ ModulationOut(*out[:3]),
154
+ ModulationOut(*out[3:]) if self.is_double else None,
155
+ )
156
+
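Note: Modulation projects the conditioning vector into shift/scale/gate triples of shape (B, 1, dim), one triple per modulated sub-layer (two when double=True). A sketch of how the blocks below consume a ModulationOut (layer_norm and sublayer are placeholders for the block's own norm/attention/MLP):

    x_mod = (1 + mod.scale) * layer_norm(x) + mod.shift  # adaLN-style input modulation
    x = x + mod.gate * sublayer(x_mod)                   # gated residual update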
157
+ Modulation_class, SelfAttention_class = Modulation, SelfAttention
158
+
159
+ class DoubleStreamBlock(nnx.Module):
160
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=jnp.float32, rngs: nnx.Rngs = None):
161
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
162
+ Modulation, SelfAttention = nn.declare_with_rng(Modulation_class, SelfAttention_class)
163
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
164
+ self.num_heads = num_heads
165
+ self.hidden_size = hidden_size
166
+ self.img_mod = Modulation(hidden_size, double=True)
167
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
168
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
169
+
170
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
171
+ self.img_mlp = nn.Sequential(
172
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
173
+ nn.GELU(approximate="tanh"),
174
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
175
+ )
176
+
177
+ self.txt_mod = Modulation(hidden_size, double=True)
178
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
179
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
180
+
181
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
182
+ self.txt_mlp = nn.Sequential(
183
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
184
+ nn.GELU(approximate="tanh"),
185
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
186
+ )
187
+
188
+ def __call__(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
189
+ img_mod1, img_mod2 = self.img_mod(vec)
190
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
191
+
192
+ # prepare image for attention
193
+ img_modulated = self.img_norm1(img)
194
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
195
+ img_qkv = self.img_attn.qkv(img_modulated)
196
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
197
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
198
+
199
+ # prepare txt for attention
200
+ txt_modulated = self.txt_norm1(txt)
201
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
202
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
203
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
204
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
205
+
206
+ # run actual attention
207
+ # q = torch.cat((txt_q, img_q), dim=2)
208
+ # k = torch.cat((txt_k, img_k), dim=2)
209
+ # v = torch.cat((txt_v, img_v), dim=2)
210
+ q = jnp.concatenate((txt_q, img_q), axis=2)
211
+ k = jnp.concatenate((txt_k, img_k), axis=2)
212
+ v = jnp.concatenate((txt_v, img_v), axis=2)
213
+
214
+
215
+ attn = attention(q, k, v, pe=pe)
216
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
217
+
218
+ # calculate the img blocks
219
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
220
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
221
+
222
+ # calculate the txt blocks
223
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
224
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
225
+ return img, txt
226
+
227
+
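Note: DoubleStreamBlock runs a single joint attention over both streams: text and image q/k/v are concatenated along the token axis (axis 2 of (B, H, L, D)), attended together, then split back at txt.shape[1]. With the config values from flux/util.py (hidden_size=3072, num_heads=24, so head_dim=128), the shapes are, as a sketch:

    # txt_q, img_q : (B, 24, L_txt, 128) and (B, 24, L_img, 128)
    # q, k, v      : (B, 24, L_txt + L_img, 128)  after concatenation
    # attn         : (B, L_txt + L_img, 3072)     split back into txt_attn / img_attn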
228
+ class SingleStreamBlock(nnx.Module):
229
+ """
230
+ A DiT block with parallel linear layers as described in
231
+ https://arxiv.org/abs/2302.05442 and an adapted modulation interface.
232
+ """
233
+
234
+ def __init__(
235
+ self,
236
+ hidden_size: int,
237
+ num_heads: int,
238
+ mlp_ratio: float = 4.0,
239
+ qk_scale: float | None = None,
240
+ dtype=jnp.float32, rngs: nnx.Rngs = None
241
+ ):
242
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
243
+ QKNorm, Modulation = nn.declare_with_rng(QKNorm_class, Modulation_class)
244
+ self.hidden_dim = hidden_size
245
+ self.num_heads = num_heads
246
+ head_dim = hidden_size // num_heads
247
+ self.scale = qk_scale or head_dim**-0.5
248
+
249
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
250
+ # qkv and mlp_in
251
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
252
+ # proj and mlp_out
253
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
254
+
255
+ self.norm = QKNorm(head_dim)
256
+
257
+ self.hidden_size = hidden_size
258
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
259
+
260
+ self.mlp_act = nn.GELU(approximate="tanh")
261
+ self.modulation = Modulation(hidden_size, double=False)
262
+
263
+ def __call__(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
264
+ mod, _ = self.modulation(vec)
265
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
266
+ # qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
267
+ qkv, mlp = jnp.split(self.linear1(x_mod), [3 * self.hidden_size,], axis=-1)
268
+
269
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
270
+ q, k = self.norm(q, k, v)
271
+
272
+ # compute attention
273
+ attn = attention(q, k, v, pe=pe)
274
+ # compute activation in mlp stream, cat again and run second linear layer
275
+ # output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
276
+ output = self.linear2(jnp.concatenate((attn, self.mlp_act(mlp)), axis=2))
277
+ return x + mod.gate * output
278
+
279
+
280
+ class LastLayer(nnx.Module):
281
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=jnp.float32, rngs: nnx.Rngs = None):
282
+ nn = TorchWrapper(rngs=rngs, dtype=dtype)
283
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
284
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
285
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
286
+
287
+ def __call__(self, x: Tensor, vec: Tensor) -> Tensor:
288
+ # shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
289
+ shift, scale = jnp.split(self.adaLN_modulation(vec), 2, axis=1)
290
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
291
+ x = self.linear(x)
292
+ return x
flux/sampling.py ADDED
@@ -0,0 +1,227 @@
1
+ import math
2
+ from typing import Callable
3
+
4
+ from einops import rearrange, repeat
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from jax import Array as Tensor
9
+ from flax import nnx
10
+
11
+ from flux.model import Flux
12
+ from flux.modules.conditioner import HFEmbedder
13
+
14
+
15
+ def get_noise(
16
+ num_samples: int,
17
+ height: int,
18
+ width: int,
19
+ device,
20
+ dtype: jnp.dtype,
21
+ seed: int,
22
+ ):
23
+ # return torch.randn(
24
+ # num_samples,
25
+ # 16,
26
+ # # allow for packing
27
+ # 2 * math.ceil(height / 16),
28
+ # 2 * math.ceil(width / 16),
29
+ # device=device,
30
+ # dtype=dtype,
31
+ # generator=torch.Generator(device=device).manual_seed(seed),
32
+ # )
33
+ # rngs = nnx.Rngs(seed)
34
+ key = jax.random.key(seed)
35
+ return jax.random.normal(
36
+ # rngs(),
37
+ key,
38
+ (
39
+ num_samples,
40
+ 2 * math.ceil(height / 16),
41
+ 2 * math.ceil(width / 16),
42
+ 16,
43
+ ),
44
+ dtype=dtype
45
+ )
46
+
47
+
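Note: the noise is sampled channels-last (NHWC) at one eighth of the pixel resolution, with each spatial dimension rounded up to an even value so the 2x2 packing in prepare() always divides cleanly. For example, as a sketch:

    x = get_noise(num_samples=1, height=1024, width=1024, device=None,
                  dtype=jnp.bfloat16, seed=0)
    assert x.shape == (1, 128, 128, 16)  # (B, H/8, W/8, C); the torch original is NCHW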
48
+ def prepare_tokens(t5: HFEmbedder, clip: HFEmbedder, prompt: str | list[str]) -> tuple[Tensor, Tensor]:
49
+ if isinstance(prompt, str):
50
+ prompt = [prompt]
51
+ t5_tokens = t5.tokenize(prompt)
52
+ clip_tokens = clip.tokenize(prompt)
53
+ return t5_tokens, clip_tokens
54
+ # return {
55
+ # "t5": t5_tokens,
56
+ # "clip": clip_tokens,
57
+ # }
58
+
59
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, t5_tokens: Tensor, clip_tokens: Tensor) -> dict[str, Tensor]:
60
+ # bs, c, h, w = img.shape
61
+ bs, h, w, c = img.shape
62
+
63
+ if bs == 1:
64
+ bs = t5_tokens.shape[0]
65
+
66
+ # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
67
+ img = rearrange(img, "b (h ph) (w pw) c -> b (h w) (c ph pw)", ph=2, pw=2)
68
+ if img.shape[0] == 1 and bs > 1:
69
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
70
+
71
+ # img_ids = torch.zeros(h // 2, w // 2, 3)
72
+ img_ids = jnp.zeros((h // 2, w // 2, 3))
73
+ # img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
74
+ # img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
75
+ img_ids = img_ids.at[..., 1].set(img_ids[..., 1]+jnp.arange(h // 2)[:, None])
76
+ img_ids = img_ids.at[..., 2].set(img_ids[..., 2]+jnp.arange(w // 2)[None, :])
77
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
78
+
79
+ # if isinstance(prompt, str):
80
+ # prompt = [prompt]
81
+ txt = t5(t5_tokens)
82
+ if txt.shape[0] == 1 and bs > 1:
83
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
84
+ # txt_ids = torch.zeros(bs, txt.shape[1], 3)
85
+ txt_ids = jnp.zeros((bs, txt.shape[1], 3))
86
+
87
+ vec = clip(clip_tokens)
88
+ if vec.shape[0] == 1 and bs > 1:
89
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
90
+
91
+ # return {
92
+ # "img": img,
93
+ # "img_ids": img_ids.to(img.device),
94
+ # "txt": txt.to(img.device),
95
+ # "txt_ids": txt_ids.to(img.device),
96
+ # "vec": vec.to(img.device),
97
+ # }
98
+ return {
99
+ "img": img,
100
+ "img_ids": img_ids,
101
+ "txt": txt,
102
+ "txt_ids": txt_ids,
103
+ "vec": vec,
104
+ }
105
+
106
+
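Note: img_ids assigns each 2x2 latent patch a (0, row, col) coordinate that EmbedND later turns into rotary position embeddings, while txt_ids stays all zeros. For a tiny 4x4 latent (so a 2x2 patch grid), as a sketch:

    # img_ids[0] == [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]]  # (axis, row, col) per patch
    # img.shape  == (B, 4, 64)  # 4 patches, 16 latent channels * 2 * 2 packed per patch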
107
+ def time_shift(mu: float, sigma: float, t: Tensor):
108
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
109
+ # return jnp.exp(mu) / (jnp.exp(mu) + (1 / t - 1) ** sigma)
110
+
111
+
112
+ def get_lin_function(
113
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
114
+ ) -> Callable[[float], float]:
115
+ m = (y2 - y1) / (x2 - x1)
116
+ b = y1 - m * x1
117
+ return lambda x: m * x + b
118
+
119
+
120
+ def get_schedule(
121
+ num_steps: int,
122
+ image_seq_len: int,
123
+ base_shift: float = 0.5,
124
+ max_shift: float = 1.15,
125
+ shift: bool = True,
126
+ ) -> Tensor:
127
+ # extra step for zero
128
+ # timesteps = torch.linspace(1, 0, num_steps + 1)
129
+ timesteps = jnp.linspace(1, 0, num_steps + 1)
130
+
131
+ # shifting the schedule to favor high timesteps for higher signal images
132
+ if shift:
133
+ # estimate mu based on linear estimation between two points
134
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
135
+ timesteps = time_shift(mu, 1.0, timesteps)
136
+
137
+ return timesteps  # .tolist()
138
+
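Note: time_shift warps the uniform schedule toward t = 1, and with sigma = 1 it reduces to t' = e^mu * t / (e^mu * t + 1 - t). For example, a 4-step schedule at roughly 1024x1024 (about 4096 packed tokens, so mu = 1.15), as a sketch:

    ts = get_schedule(num_steps=4, image_seq_len=4096, shift=True)
    # uniform [1.0, 0.75, 0.5, 0.25, 0.0] is shifted to roughly [1.0, 0.90, 0.76, 0.51, 0.0]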
139
+ DEBUG=False
140
+
141
+ def denoise_for(
142
+ model: Flux,
143
+ # model input
144
+ img: Tensor,
145
+ img_ids: Tensor,
146
+ txt: Tensor,
147
+ txt_ids: Tensor,
148
+ vec: Tensor,
149
+ # sampling parameters
150
+ timesteps: Tensor,
151
+ guidance: float = 4.0,
152
+ ):
153
+ # this is ignored for schnell
154
+ # guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
155
+ guidance_vec = jnp.full((img.shape[0],), guidance, dtype=img.dtype)
156
+ timesteps = timesteps.tolist()
157
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
158
+ # t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
159
+ t_vec = jnp.full((img.shape[0],), t_curr, dtype=img.dtype)
160
+ pred = model(
161
+ img=img,
162
+ img_ids=img_ids,
163
+ txt=txt,
164
+ txt_ids=txt_ids,
165
+ y=vec,
166
+ timesteps=t_vec,
167
+ guidance=guidance_vec,
168
+ )
169
+
170
+ img = img + (t_prev - t_curr) * pred
171
+ return img
172
+
173
+
174
+ # @nnx.jit
175
+ def denoise(
176
+ model: Flux,
177
+ # model input
178
+ img: Tensor,
179
+ img_ids: Tensor,
180
+ txt: Tensor,
181
+ txt_ids: Tensor,
182
+ vec: Tensor,
183
+ # sampling parameters
184
+ timesteps: Tensor,
185
+ guidance: float = 4.0,
186
+ ):
187
+ # this is ignored for schnell
188
+ # guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
189
+ guidance_vec = jnp.full((img.shape[0],), guidance, dtype=img.dtype)
190
+ @nnx.scan
191
+ def scan_func(acc, t_prev):
192
+ img, t_curr = acc
193
+ dtype = img.dtype
194
+ t_vec = jnp.full((img.shape[0],), t_curr, dtype=img.dtype)
195
+ pred = model(
196
+ img=img,
197
+ img_ids=img_ids,
198
+ txt=txt,
199
+ txt_ids=txt_ids,
200
+ y=vec,
201
+ timesteps=t_vec,
202
+ guidance=guidance_vec,
203
+ )
204
+
205
+ img = img + (t_prev - t_curr) * pred
206
+ return (img.astype(dtype), t_prev), pred
207
+ acc,pred=scan_func((img, timesteps[0]), timesteps[1:])
208
+ return acc[0]
209
+
210
+
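Note: denoise is the scan form of denoise_for above: nnx.scan carries (img, t_curr) and consumes timesteps[1:] as the per-step t_prev, so jit traces the model body once rather than once per step. Unrolled, the updates are identical, e.g. for ts = [1.0, 0.5, 0.0]:

    # step 1: img = img + (0.5 - 1.0) * model(img, t=1.0, ...)
    # step 2: img = img + (0.0 - 0.5) * model(img, t=0.5, ...)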
211
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
212
+ # return rearrange(
213
+ # x,
214
+ # "b (h w) (c ph pw) -> b c (h ph) (w pw)",
215
+ # h=math.ceil(height / 16),
216
+ # w=math.ceil(width / 16),
217
+ # ph=2,
218
+ # pw=2,
219
+ # )
220
+ return rearrange(
221
+ x,
222
+ "b (h w) (c ph pw) -> b (h ph) (w pw) c",
223
+ h=math.ceil(height / 16),
224
+ w=math.ceil(width / 16),
225
+ ph=2,
226
+ pw=2,
227
+ )
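Note: unpack is the channels-last inverse of the packing done in prepare(). A round-trip check, as a sketch using this module's imports:

    z = jnp.zeros((1, 128, 128, 16))  # (B, H/8, W/8, C) latent for a 1024x1024 image
    packed = rearrange(z, "b (h ph) (w pw) c -> b (h w) (c ph pw)", ph=2, pw=2)
    assert packed.shape == (1, 4096, 64)  # 64*64 tokens, each of dim 16*2*2
    assert unpack(packed, height=1024, width=1024).shape == z.shape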
flux/util.py ADDED
@@ -0,0 +1,340 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ import jax
6
+ from jax import Array as Tensor
7
+ import jax.numpy as jnp
8
+ from flax import nnx
9
+ import torch
10
+ from einops import rearrange
11
+ from huggingface_hub import hf_hub_download
12
+ from imwatermark import WatermarkEncoder
13
+ from safetensors.torch import load_file as load_sft
14
+
15
+ from flux.model import Flux, FluxParams
16
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
17
+ from flux.modules.conditioner import HFEmbedder
18
+
19
+
20
+
21
+
22
+ @dataclass
23
+ class ModelSpec:
24
+ params: FluxParams
25
+ ae_params: AutoEncoderParams
26
+ ckpt_path: str | None
27
+ ae_path: str | None
28
+ repo_id: str | None
29
+ repo_flow: str | None
30
+ repo_ae: str | None
31
+
32
+
33
+ configs = {
34
+ "flux-dev": ModelSpec(
35
+ repo_id="black-forest-labs/FLUX.1-dev",
36
+ repo_flow="flux1-dev.safetensors",
37
+ repo_ae="ae.safetensors",
38
+ ckpt_path=os.getenv("FLUX_DEV"),
39
+ params=FluxParams(
40
+ in_channels=64,
41
+ vec_in_dim=768,
42
+ context_in_dim=4096,
43
+ hidden_size=3072,
44
+ mlp_ratio=4.0,
45
+ num_heads=24,
46
+ depth=19,
47
+ depth_single_blocks=38,
48
+ axes_dim=[16, 56, 56],
49
+ theta=10_000,
50
+ qkv_bias=True,
51
+ guidance_embed=True,
52
+ ),
53
+ ae_path=os.getenv("AE"),
54
+ ae_params=AutoEncoderParams(
55
+ resolution=256,
56
+ in_channels=3,
57
+ ch=128,
58
+ out_ch=3,
59
+ ch_mult=[1, 2, 4, 4],
60
+ num_res_blocks=2,
61
+ z_channels=16,
62
+ scale_factor=0.3611,
63
+ shift_factor=0.1159,
64
+ ),
65
+ ),
66
+ "flux-schnell": ModelSpec(
67
+ repo_id="black-forest-labs/FLUX.1-schnell",
68
+ repo_flow="flux1-schnell.safetensors",
69
+ repo_ae="ae.safetensors",
70
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
71
+ params=FluxParams(
72
+ in_channels=64,
73
+ vec_in_dim=768,
74
+ context_in_dim=4096,
75
+ hidden_size=3072,
76
+ mlp_ratio=4.0,
77
+ num_heads=24,
78
+ depth=19,
79
+ depth_single_blocks=38,
80
+ axes_dim=[16, 56, 56],
81
+ theta=10_000,
82
+ qkv_bias=True,
83
+ guidance_embed=False,
84
+ ),
85
+ ae_path=os.getenv("AE"),
86
+ ae_params=AutoEncoderParams(
87
+ resolution=256,
88
+ in_channels=3,
89
+ ch=128,
90
+ out_ch=3,
91
+ ch_mult=[1, 2, 4, 4],
92
+ num_res_blocks=2,
93
+ z_channels=16,
94
+ scale_factor=0.3611,
95
+ shift_factor=0.1159,
96
+ ),
97
+ ),
98
+ }
99
+
100
+
101
+ try:
102
+ import ml_dtypes
103
+ from_torch_bf16 = lambda x: jnp.asarray(x.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16))
104
+ except ImportError:
105
+ from_torch_bf16 = lambda x: jnp.asarray(x.float().numpy()).astype(jnp.bfloat16)
106
+
107
+ def load_from_torch(graph, state, state_dict:dict):
108
+ cnt=0
109
+ torch_cnt=0
110
+ flax_cnt=0
111
+ val_cnt=0
112
+ print(f"Torch states: #{len(state_dict)}; Flax states: #{len(state.flat_state())}")
113
+ def convert_to_jax(tensor):
114
+ if tensor.dtype==torch.bfloat16:
115
+ return from_torch_bf16(tensor)
116
+ else:
117
+ return jnp.asarray(tensor.numpy())
118
+ for key in sorted(state_dict.keys()):
119
+ ptr=state
120
+ node=graph
121
+ torch_cnt+=1
122
+ # print(key)
123
+ try:
124
+ for loc in key.split(".")[:-1]:
125
+ if loc.isnumeric():
126
+ if "layers" in ptr:
127
+ ptr=ptr["layers"]
128
+ node=node.subgraphs["layers"]
129
+ loc=int(loc)
130
+ ptr=ptr[loc]
131
+ node=node.subgraphs[loc]
132
+ last=key.split(".")[-1]
133
+ if last not in ptr._mapping.keys():
134
+ ptr_keys=list(ptr._mapping.keys())
135
+ ptr_keys=list(filter(lambda x:x!="bias", ptr_keys))
136
+ if len(ptr_keys)==1:
137
+ ptr_key=ptr_keys[0]
138
+ elif last=="weight" and "kernel" in ptr_keys:
139
+ ptr_key="kernel"
140
+ else:
141
+ cnt+=1
142
+ raise Exception(f"Mismatched: {key}: {ptr_keys} ")
143
+ val=ptr[ptr_key].value
144
+ # assert state_dict[key].shape==val.shape, f"[{node.type}]mismatched {state_dict[key].shape} {val.shape}"
145
+ else:
146
+ if isinstance(ptr[last], jax.Array):
147
+ val=ptr[last]
148
+ else:
149
+ val=ptr[last].value
150
+ ptr_key=last
151
+ assert state_dict[key].shape==val.shape, f"{key} mismatched"
152
+
153
+ if isinstance(ptr[ptr_key], jax.Array):
154
+ assert state_dict[key].shape==val.shape, f"Array: [{node.type}]mismatched {state_dict[key].shape} {val.shape}"
155
+ kernel=convert_to_jax(state_dict[key])
156
+ val_cnt+=1
157
+ continue
158
+ elif ptr_key=="bias":
159
+ assert state_dict[key].shape==val.shape, f"Bias: [{node.type}]mismatched {state_dict[key].shape} {val.shape}"
160
+ kernel=nnx.Param(convert_to_jax(state_dict[key])).to_state()
161
+ else:
162
+ # print(node.type,node.attributes, )
163
+ # print(type(ptr._mapping[ptr_key]))
164
+ if 'kernel_size' in node.attributes:
165
+ kernel=convert_to_jax(state_dict[key])
166
+ # print(len(kernel.shape))
167
+ # print(kernel.shape)
168
+ if len(kernel.shape)==3:
169
+ kernel=jnp.transpose(kernel, (2, 1, 0))
170
+ elif len(kernel.shape)==4:
171
+ kernel=jnp.transpose(kernel, (2, 3, 1, 0))
172
+ elif len(kernel.shape)==5:
173
+ kernel=jnp.transpose(kernel, (2, 3, 4, 1, 0))
174
+ elif 'dot_general' in node.attributes:
175
+ kernel=convert_to_jax(state_dict[key])
176
+ kernel=jnp.transpose(kernel, (1, 0))
177
+ else:
178
+ # val=ptr[ptr_key].value
179
+ kernel=convert_to_jax(state_dict[key])
180
+ assert val.shape==kernel.shape, f"[{node.type}]mismatched {val.shape} {kernel.shape}"
181
+ kernel=nnx.Param(kernel).to_state()
182
+ # print("new", len(kernel.value.shape), type(kernel))
183
+ ptr._mapping[ptr_key]=kernel
184
+ flax_cnt+=1
185
+ except Exception as e:
186
+ print(e, f"{key}")
187
+ print(cnt, torch_cnt, flax_cnt, val_cnt)
188
+ # print(len(state.flat_state()))
189
+ return state
190
+
191
+ def load_state_dict(model, state_dict):
192
+ graph,state=nnx.split(model)
193
+ state=load_from_torch(graph, state, state_dict)
194
+ nnx.update(model, state)
195
+ return model
196
+
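Note: the transposes in load_from_torch above implement the usual torch-to-flax weight-layout conversion; a sketch of the conventions it relies on:

    # nn.Linear weight : torch (out, in)         -> flax kernel (in, out)          via transpose(1, 0)
    # Conv1d weight    : torch (out, in, k)      -> flax kernel (k, in, out)       via transpose(2, 1, 0)
    # Conv2d weight    : torch (out, in, kh, kw) -> flax kernel (kh, kw, in, out)  via transpose(2, 3, 1, 0)
    # biases and norm scales are copied as-is; torch.bfloat16 tensors go through from_torch_bf16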
197
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
198
+ if len(missing) > 0 and len(unexpected) > 0:
199
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
200
+ print("\n" + "-" * 79 + "\n")
201
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
202
+ elif len(missing) > 0:
203
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
204
+ elif len(unexpected) > 0:
205
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
206
+
207
+
208
+ def patch_dtype(model,dtype,patch_param=False):
209
+ for path, module in model.iter_modules():
210
+ if hasattr(module, "dtype") and (module.dtype is None or jnp.issubdtype(module.dtype, jnp.floating)):
211
+ module.dtype=dtype
212
+ if patch_param:
213
+ if hasattr(module, "param_dtype") and jnp.issubdtype(module.param_dtype, jnp.floating):
214
+ module.param_dtype=dtype
215
+ if not patch_param:
216
+ return model
217
+ for path, parent in nnx.iter_graph(model):
218
+ if isinstance(parent, nnx.Module):
219
+ for name, value in vars(parent).items():
220
+ if isinstance(value, nnx.Variable) and value.value is None:
221
+ pass
222
+ # print(name)
223
+ elif isinstance(value, nnx.Variable):
224
+ if jnp.issubdtype(value.value.dtype, jnp.floating):
225
+ value.value = value.value.astype(dtype)
226
+ # print(name,value.value.dtype,value.dtype)
227
+ elif isinstance(value,jax.Array):
228
+ # print(name,value.dtype)
229
+ # print(parent.__getattribute__(name).dtype)
230
+ if jnp.issubdtype(value.dtype, jnp.floating):
231
+ parent.__setattr__(name,value.astype(dtype))
232
+ return model
233
+
234
+
235
+ def load_flow_model(name: str, device: str = "none", hf_download: bool = True):
236
+ # Loading Flux
237
+ print("Init model")
238
+ ckpt_path = configs[name].ckpt_path
239
+ if (
240
+ ckpt_path is None
241
+ and configs[name].repo_id is not None
242
+ and configs[name].repo_flow is not None
243
+ and hf_download
244
+ ):
245
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
246
+
247
+ # with torch.device("meta" if ckpt_path is not None else device):
248
+ model = Flux(configs[name].params, dtype=jnp.bfloat16, rngs=nnx.Rngs(0))
249
+ model = patch_dtype(model, jnp.bfloat16)
250
+ if ckpt_path is not None:
251
+ print("Loading checkpoint")
252
+ # load_sft doesn't support torch.device
253
+ sd = load_sft(ckpt_path, device="cpu")
254
+ # TODO: loading state_dict
255
+ model = load_state_dict(model, sd)
256
+ # missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
257
+ # print_load_warning(missing, unexpected)
258
+ return model
259
+
260
+
261
+ def load_t5(device: str = "none", max_length: int = 512) -> HFEmbedder:
262
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
263
+ return HFEmbedder("lnyan/t5-v1_1-xxl-encoder", max_length=max_length, dtype=jnp.bfloat16)
264
+
265
+
266
+ def load_clip(device: str = "none") -> HFEmbedder:
267
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=jnp.bfloat16)
268
+
269
+
270
+ def load_ae(name: str, device: str = "none", hf_download: bool = True) -> AutoEncoder:
271
+ ckpt_path = configs[name].ae_path
272
+ if (
273
+ ckpt_path is None
274
+ and configs[name].repo_id is not None
275
+ and configs[name].repo_ae is not None
276
+ and hf_download
277
+ ):
278
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
279
+
280
+ # Loading the autoencoder
281
+ print("Init AE")
282
+ # with torch.device("meta" if ckpt_path is not None else device):
283
+ ae = AutoEncoder(configs[name].ae_params, dtype=jnp.bfloat16, rngs=nnx.Rngs(0))
284
+ ae = patch_dtype(ae, jnp.bfloat16)
285
+
286
+ if ckpt_path is not None:
287
+ sd = load_sft(ckpt_path, device="cpu")
288
+ # TODO: loading state_dict
289
+ ae = load_state_dict(ae, sd)
290
+ # missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
291
+ # print_load_warning(missing, unexpected)
292
+ return ae
293
+
294
+
295
+ class WatermarkEmbedder:
296
+ def __init__(self, watermark):
297
+ self.watermark = watermark
298
+ self.num_bits = len(WATERMARK_BITS)
299
+ self.encoder = WatermarkEncoder()
300
+ self.encoder.set_watermark("bits", self.watermark)
301
+
302
+ def __call__(self, image: Tensor) -> Tensor:
303
+ """
304
+ Adds a predefined watermark to the input image
305
+
306
+ Args:
307
+ image: ([N,] B, H, W, RGB) in range [-1, 1]
308
+
309
+ Returns:
310
+ same as input but watermarked
311
+ """
312
+ image = 0.5 * image + 0.5
313
+ squeeze = len(image.shape) == 4
314
+ if squeeze:
315
+ image = image[None, ...]
316
+ n = image.shape[0]
317
+ # image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
318
+ image_np = np.array(rearrange((255 * image), "n b h w c -> (n b) h w c"))[:, :, :, ::-1]
319
+
320
+ # jax (n*b, h, w, c) in [0, 1] -> numpy (n*b, h, w, c) in [0, 255]
321
+ # watermarking library expects input as cv2 BGR format
322
+ for k in range(image_np.shape[0]):
323
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
324
+ # image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
325
+ # image.device
326
+ # )
327
+ image = jnp.asarray(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b h w c", n=n))
328
+ # image = torch.clamp(image / 255, min=0.0, max=1.0)
329
+ image = jnp.clip(image / 255, min=0.0, max=1.0)
330
+ if squeeze:
331
+ image = image[0]
332
+ image = 2 * image - 1
333
+ return image
334
+
335
+
336
+ # A fixed 48-bit message that was chosen at random
337
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
338
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
339
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
340
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
flux/wrapper.py ADDED
@@ -0,0 +1,139 @@
1
+ # Copyright 2024 Lnyan (https://github.com/lkwq007). All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from functools import partial
17
+
18
+
19
+ import numpy as np
20
+ import jax
21
+ import jax.numpy as jnp
22
+ from jax import Array as Tensor
23
+ import flax
24
+ from flax import nnx
25
+ import flax.linen
26
+
27
+
28
+ def fake_init(key, feature_shape, param_dtype):
29
+ return jax.ShapeDtypeStruct(feature_shape, param_dtype)
30
+
31
+
32
+ def wrap_LayerNorm(dim, *, eps=1e-5, elementwise_affine=True, bias=True, rngs:nnx.Rngs):
33
+ return nnx.LayerNorm(dim, epsilon=eps, use_bias=elementwise_affine and bias, use_scale=elementwise_affine, bias_init=fake_init, scale_init=fake_init, rngs=rngs)
34
+
35
+ def wrap_Linear(dim, inner_dim, *, bias=True, rngs:nnx.Rngs):
36
+ return nnx.Linear(dim, inner_dim, use_bias=bias, kernel_init=fake_init, bias_init=fake_init, rngs=rngs)
37
+
38
+
39
+ def wrap_GroupNorm(num_groups, num_channels, *, eps=1e-5, affine=True, rngs:nnx.Rngs):
40
+ return nnx.GroupNorm(num_channels, num_groups=num_groups, epsilon=eps, use_bias=affine, use_scale=affine, bias_init=fake_init, scale_init=fake_init, rngs=rngs)
41
+
42
+ def wrap_Conv(in_channels, out_channels, kernel_size, *, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', rngs:nnx.Rngs, conv_dim:int):
43
+ if isinstance(kernel_size, int):
44
+ kernel_tuple = (kernel_size,) * conv_dim
45
+ else:
46
+ # elif isinstance(kernel_size, tuple):
47
+ assert len(kernel_size) == conv_dim
48
+ kernel_tuple = kernel_size
49
+ return nnx.Conv(in_channels, out_channels, kernel_tuple, strides=stride, padding=padding, use_bias=bias, kernel_init=fake_init, bias_init=fake_init, rngs=rngs)
50
+ # return nnx.Conv(in_channels, out_channels, kernel_tuple, stride=stride, padding=padding, dilation=dilation, feature_group_count=groups, use_bias=bias, rngs=rngs)
51
+
52
+
53
+ class nn_GELU(nnx.Module):
54
+ def __init__(self, approximate="none") -> None:
55
+ self.approximate = (approximate == "tanh")
56
+
57
+ def __call__(self, x):
58
+ return nnx.gelu(x, approximate=self.approximate)
59
+
60
+ class nn_SiLU(nnx.Module):
61
+ def __init__(self) -> None:
62
+ pass
63
+
64
+ def __call__(self, x):
65
+ return nnx.silu(x)
66
+
67
+ class nn_AvgPool(nnx.Module):
68
+ def __init__(self, window_shape, strides=None, padding="VALID") -> None:
69
+ self.window_shape=window_shape
70
+ self.strides=strides
71
+ self.padding=padding
72
+
73
+ def __call__(self, x):
74
+ return flax.linen.avg_pool(x, window_shape=self.window_shape, strides=self.strides, padding=self.padding)
75
+
76
+
77
+ # a wrapper class
78
+ class TorchWrapper:
79
+ def __init__(self, rngs: nnx.Rngs, dtype=jnp.float32):
80
+ self.rngs = rngs
81
+ self.dtype = dtype
82
+
83
+ def declare_with_rng(self, *args):
84
+ ret=list(map(lambda f: partial(f, dtype=self.dtype, rngs=self.rngs), args))
85
+ return ret if len(ret)>1 else ret[0]
86
+
87
+ def conv_nd(self, dims, *args, **kwargs):
88
+ return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=dims)
89
+
90
+ def avg_pool(self, *args, **kwargs):
91
+ return nn_AvgPool(*args, **kwargs)
92
+
93
+
94
+ def linear(self, *args, **kwargs):
95
+ return self.Linear(*args, **kwargs)
96
+
97
+ def SiLU(self):
98
+ return nn_SiLU()
99
+
100
+ def GELU(self, approximate="none"):
101
+ return nn_GELU(approximate)
102
+
103
+ def Identity(self):
104
+ return lambda x: x
105
+
106
+ def LayerNorm(self, dim, eps=1e-5, elementwise_affine=True, bias=True):
107
+ return wrap_LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, bias=bias, rngs=self.rngs)
108
+
109
+ def GroupNorm(self, *args, **kwargs):
110
+ return wrap_GroupNorm(*args,**kwargs, rngs=self.rngs)
111
+
112
+ def Linear(self, *args, **kwargs):
113
+ return wrap_Linear(*args, **kwargs, rngs=self.rngs)
114
+
115
+ def Parameter(self, value):
116
+ return nnx.Param(value)
117
+
118
+ def Dropout(self, p):
119
+ return nnx.Dropout(rate=p, rngs=self.rngs)
120
+
121
+ def Sequential(self, *args):
122
+ return nnx.Sequential(*args)
123
+
124
+ def Conv1d(self, *args, **kwargs):
125
+ return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=1)
126
+
127
+ def Conv2d(self, *args, **kwargs):
128
+ return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=2)
129
+
130
+ def Conv3d(self, *args, **kwargs):
131
+ return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=3)
132
+
133
+ def ModuleList(self, lst=None):
134
+ if lst is None:
135
+ return []
136
+ return list(lst)
137
+
138
+ def Module(self,*args,**kwargs):
139
+ return nnx.Dict()
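Note: TorchWrapper is a shim that keeps the ported modules' torch-style constructor calls while building flax.nnx layers; parameters are created as abstract ShapeDtypeStructs via fake_init and only become real arrays once a torch checkpoint is loaded with load_state_dict in flux/util.py. A minimal construction sketch, mirroring how MLPEmbedder in flux/modules/layers.py uses it:

    from flax import nnx
    import jax.numpy as jnp

    nn = TorchWrapper(rngs=nnx.Rngs(0), dtype=jnp.bfloat16)
    mlp = nn.Sequential(
        nn.Linear(256, 3072, bias=True),  # becomes nnx.Linear(256, 3072, use_bias=True, ...)
        nn.SiLU(),
        nn.Linear(3072, 3072, bias=True),
    )
    # Do not call mlp yet: its kernels are placeholders until real weights are assigned.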
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ jax[cuda12]
2
+ flax==0.9.0
3
+ flash_attn_jax
4
+ torch
5
+ torchvision
6
+ opencv-python-headless
7
+ einops
8
+ huggingface_hub
9
+ transformers
10
+ tokenizers
11
+ sentencepiece
12
+ fire
13
+ invisible-watermark
14
+ ml-dtypes