PommesPeter committed on
Commit
5aadc08
1 Parent(s): 24e677f

Upload 8 files

Files changed (6)
  1. app.py +598 -0
  2. models/__init__.py +2 -0
  3. models/components.py +54 -0
  4. models/model.py +908 -0
  5. models/model_5b.py +894 -0
  6. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,598 @@
1
+ import argparse
2
+ import builtins
3
+ import json
4
+ import multiprocessing as mp
5
+ import os, sys
6
+ import random
7
+ import socket
8
+ import traceback
9
+
10
+ import fairscale.nn.model_parallel.initialize as fs_init
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torchvision.transforms.functional import to_pil_image
16
+
17
+ import models
18
+ from PIL import Image
19
+ from lumina_t2i.transport import create_transport, Sampler
20
+
21
+ description = """
22
+ # Lumina Next Text-to-Image
23
+
24
+ Lumina-Next-T2I is a 2B Next-DiT model with a 2B text encoder.
25
+
26
+ Demo current model: `Lumina-Next-T2I`
27
+
28
+ ### <span style='color: red;'>Due to the high volume of traffic, we have temporarily disabled the resolution extrapolation functionality.</span>
29
+
30
+ ### Additionally, we offer alternative links for Lumina-T2X access; if this site is busy, please try one of the other demos. [[demo1](http://106.14.2.150:10022/)] [[demo2](http://106.14.2.150:10023/)]
31
+
32
+ """
33
+
34
+ examples = [
35
+ ["👽🤖👹👻"],
36
+ ["孤舟蓑笠翁"],
37
+ ["两只黄鹂鸣翠柳"],
38
+ ["大漠孤烟直,长河落日圆"],
39
+ ["秋风起兮白云飞,草木黄落兮雁南归"],
40
+ ["도쿄 타워, 최고 품질의 우키요에, 에도 시대"],
41
+ ["味噌ラーメン, 最高品質の浮世絵、江戸時代。"],
42
+ ["東京タワー、最高品質の浮世絵、江戸時代。"],
43
+ ["Astronaut on Mars During sunset"],
44
+ ["Tour de Tokyo, estampes ukiyo-e de la plus haute qualité, période Edo"],
45
+ ["🐔 playing 🏀"],
46
+ ["☃️ with 🌹 in the ❄️"],
47
+ ["🐶 wearing 😎 flying on 🌈 "],
48
+ ["A small 🍎 and 🍊 with 😁 emoji in the Sahara desert"],
49
+ ["Токийская башня, лучшие укиё-э, период Эдо"],
50
+ ["Tokio-Turm, hochwertigste Ukiyo-e, Edo-Zeit"],
51
+ ["A scared cute rabbit in Happy Tree Friends style and punk vibe."], # noqa
52
+ ["A humanoid eagle soldier of the First World War."], # noqa
53
+ ["A cute Christmas mockup on an old wooden industrial desk table with Christmas decorations and bokeh lights in the background."],
54
+ ["A front view of a romantic flower shop in France filled with various blooming flowers including lavenders and roses."],
55
+ ["An old man, portrayed as a retro superhero, stands in the streets of New York City at night"],
56
+ ["many trees are surrounded by a lake in autumn colors, in the style of nature-inspired imagery, havencore, brightly colored, dark white and dark orange, bright primary colors, environmental activism, forestpunk --ar 64:51"],
57
+ ["A fluffy mouse holding a watermelon, in a magical and colorful setting, illustrated in the style of Hayao Miyazaki anime by Studio Ghibli."],
58
+ ["Inka warrior with a war make up, medium shot, natural light, Award winning wildlife photography, hyperrealistic, 8k resolution, --ar 9:16"],
59
+ ["Character of lion in style of saiyan, mafia, gangsta, citylights background, Hyper detailed, hyper realistic, unreal engine ue5, cgi 3d, cinematic shot, 8k"],
60
+ ["In the sky above, a giant, whimsical cloud shaped like the 😊 emoji casts a soft, golden light over the scene"],
61
+ ["Cyberpunk eagle, neon ambiance, abstract black oil, gear mecha, detailed acrylic, grunge, intricate complexity, rendered in unreal engine 5, photorealistic, 8k"],
62
+ ["close-up photo of a beautiful red rose breaking through a cube made of ice , splintered cracked ice surface, frosted colors, blood dripping from rose, melting ice, Valentine’s Day vibes, cinematic, sharp focus, intricate, cinematic, dramatic light"],
63
+ ["3D cartoon Fox Head with Human Body, Wearing Iridescent Holographic Liquid Texture & Translucent Material Sun Protective Shirt, Boss Feel, Nike or Addidas Sun Protective Shirt, WitchPunk, Y2K Style, Green and blue, Blue, Metallic Feel, Strong Reflection, plain background, no background, pure single color background, Digital Fashion, Surreal Futurism, Supreme Kong NFT Artwork Style, disney style, headshot photography for portrait studio shoot, fashion editorial aesthetic, high resolution in the style of HAPE PRIME NFT, NFT 3D IP Feel, Bored Ape Yacht Club NFT project Feel, high detail, fine luster, 3D render, oc render, best quality, 8K, bright, front lighting, Face Shot, fine luster, ultra detailed"],
64
+ ]
65
+
66
+ class ModelFailure:
67
+ pass
68
+
69
+
70
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
71
+ def encode_prompt(
72
+ prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True
73
+ ):
74
+
75
+ captions = []
76
+ for caption in prompt_batch:
77
+ if random.random() < proportion_empty_prompts:
78
+ captions.append("")
79
+ elif isinstance(caption, str):
80
+ captions.append(caption)
81
+ elif isinstance(caption, (list, np.ndarray)):
82
+ # take a random caption if there are multiple
83
+ captions.append(random.choice(caption) if is_train else caption[0])
84
+
85
+ with torch.no_grad():
86
+ text_inputs = tokenizer(
87
+ captions,
88
+ padding=True,
89
+ pad_to_multiple_of=8,
90
+ max_length=256,
91
+ truncation=True,
92
+ return_tensors="pt",
93
+ )
94
+
95
+ text_input_ids = text_inputs.input_ids
96
+ prompt_masks = text_inputs.attention_mask
97
+
98
+ prompt_embeds = text_encoder(
99
+ input_ids=text_input_ids.cuda(),
100
+ attention_mask=prompt_masks.cuda(),
101
+ output_hidden_states=True,
102
+ ).hidden_states[-2]
103
+
104
+ return prompt_embeds, prompt_masks
105
+
106
+
107
+ @torch.no_grad()
108
+ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrier):
109
+ # import here to avoid huggingface Tokenizer parallelism warnings
110
+ from diffusers.models import AutoencoderKL
111
+ from transformers import AutoModelForCausalLM, AutoTokenizer
112
+
113
+ # override the default print function since the delay can be large for child process
114
+ original_print = builtins.print
115
+
116
+ # Redefine the print function with flush=True by default
117
+ def print(*args, **kwargs):
118
+ kwargs.setdefault("flush", True)
119
+ original_print(*args, **kwargs)
120
+
121
+ # Override the built-in print with the new version
122
+ builtins.print = print
123
+
124
+ os.environ["MASTER_PORT"] = str(master_port)
125
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
126
+ os.environ["RANK"] = str(rank)
127
+ os.environ["WORLD_SIZE"] = str(args.num_gpus)
128
+
129
+ dist.init_process_group("nccl")
130
+ # set up fairscale environment because some methods of the Lumina model need it,
131
+ # though for single-GPU inference fairscale actually has no effect
132
+ fs_init.initialize_model_parallel(args.num_gpus)
133
+ torch.cuda.set_device(rank)
134
+
135
+ train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
136
+ if dist.get_rank() == 0:
137
+ print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))
138
+
139
+ if dist.get_rank() == 0:
140
+ print(f"Creating lm: Gemma-2B")
141
+
142
+ dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
143
+ args.precision
144
+ ]
145
+
146
+ text_encoder = (
147
+ AutoModelForCausalLM.from_pretrained(
148
+ "google/gemma-2b", torch_dtype=dtype, device_map="cuda"
149
+ )
150
+ .get_decoder()
151
+ .eval()
152
+ )
153
+ cap_feat_dim = text_encoder.config.hidden_size
154
+ if args.num_gpus > 1:
155
+ raise NotImplementedError("Inference with >1 GPUs not yet supported")
156
+
157
+ tokenizer = AutoTokenizer.from_pretrained(
158
+ "google/gemma-2b", add_bos_token=True, add_eos_token=True
159
+ )
160
+ tokenizer.padding_side = "right"
161
+
162
+ if dist.get_rank() == 0:
163
+ print(f"Creating vae: sdxl-vae")
164
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae",
165
+ torch_dtype=torch.float32,
166
+ ).cuda()
167
+
168
+ if dist.get_rank() == 0:
169
+ print(f"Creating DiT: Next-DiT")
170
+ # latent_size = train_args.image_size // 8
171
+ model = models.__dict__["DiT_Llama_2B_patch2"](
172
+ qk_norm=train_args.qk_norm,
173
+ cap_feat_dim=cap_feat_dim,
174
+ )
175
+ model.eval().to("cuda", dtype=dtype)
176
+
177
+ assert train_args.model_parallel_size == args.num_gpus
178
+ if args.ema:
179
+ print("Loading ema model.")
180
+ ckpt = torch.load(
181
+ os.path.join(
182
+ args.ckpt,
183
+ f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.pth",
184
+ ),
185
+ map_location="cpu",
186
+ )
187
+ model.load_state_dict(ckpt, strict=True)
188
+
189
+ mp_barrier.wait()
190
+
191
+ with torch.autocast("cuda", dtype):
192
+ while True:
193
+ (
194
+ cap,
195
+ resolution,
196
+ num_sampling_steps,
197
+ cfg_scale,
198
+ solver,
199
+ t_shift,
200
+ seed,
201
+ ntk_scaling,
202
+ proportional_attn,
203
+ ) = request_queue.get()
204
+
205
+ print(
206
+ "> params:",
207
+ cap,
208
+ resolution,
209
+ num_sampling_steps,
210
+ cfg_scale,
211
+ solver,
212
+ t_shift,
213
+ seed,
214
+ ntk_scaling,
215
+ proportional_attn,
216
+ )
217
+ try:
218
+ # begin sampler
219
+ transport = create_transport(
220
+ args.path_type,
221
+ args.prediction,
222
+ args.loss_weight,
223
+ args.train_eps,
224
+ args.sample_eps,
225
+ )
226
+ sampler = Sampler(transport)
227
+ if args.sampler_mode == "ODE":
228
+ if args.likelihood:
229
+ # assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" # todo
230
+ sample_fn = sampler.sample_ode_likelihood(
231
+ sampling_method=solver,
232
+ num_steps=num_sampling_steps,
233
+ atol=args.atol,
234
+ rtol=args.rtol,
235
+ )
236
+ else:
237
+ sample_fn = sampler.sample_ode(
238
+ sampling_method=solver,
239
+ num_steps=num_sampling_steps,
240
+ atol=args.atol,
241
+ rtol=args.rtol,
242
+ reverse=args.reverse,
243
+ time_shifting_factor=t_shift,
244
+ )
245
+ elif args.sampler_mode == "SDE":
246
+ sample_fn = sampler.sample_sde(
247
+ sampling_method=solver,
248
+ diffusion_form=args.diffusion_form,
249
+ diffusion_norm=args.diffusion_norm,
250
+ last_step=args.last_step,
251
+ last_step_size=args.last_step_size,
252
+ num_steps=num_sampling_steps,
253
+ )
254
+ # end sampler
255
+
256
+ resolution = resolution.split(" ")[-1]
257
+ w, h = resolution.split("x")
258
+ w, h = int(w), int(h)
259
+ latent_w, latent_h = w // 8, h // 8
260
+ if int(seed) != 0:
261
+ torch.random.manual_seed(int(seed))
262
+ z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
263
+ z = z.repeat(2, 1, 1, 1)
264
+
265
+ with torch.no_grad():
266
+ cap_feats, cap_mask = encode_prompt(
267
+ [cap] + [""], text_encoder, tokenizer, 0.0
268
+ )
269
+ cap_mask = cap_mask.to(cap_feats.device)
270
+
271
+ train_res = 1024
272
+ res_cat = (w * h) ** 0.5
273
+ print(f"res_cat: {res_cat}")
274
+ max_seq_len = (res_cat // 16) ** 2 + (res_cat // 16) * 2
275
+ print(f"max_seq_len: {max_seq_len}")
276
+
277
+ rope_scaling_factor = 1.0
278
+ ntk_factor = max_seq_len / (train_res // 16) ** 2
279
+ print(f"ntk_factor: {ntk_factor}")
280
+
281
+ model_kwargs = dict(
282
+ cap_feats=cap_feats,
283
+ cap_mask=cap_mask,
284
+ cfg_scale=cfg_scale,
285
+ rope_scaling_factor=rope_scaling_factor,
286
+ ntk_factor=ntk_factor,
287
+ )
288
+
289
+ if dist.get_rank() == 0:
290
+ print(f"caption: {cap}")
291
+ print(f"num_sampling_steps: {num_sampling_steps}")
292
+ print(f"cfg_scale: {cfg_scale}")
293
+
294
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
295
+ print("> [debug] start sample")
296
+ samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
297
+ samples = samples[:1]
298
+
299
+ factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
300
+ print(f"vae factor: {factor}")
301
+ samples = vae.decode(samples / factor).sample
302
+ samples = (samples + 1.0) / 2.0
303
+ samples.clamp_(0.0, 1.0)
304
+ img = to_pil_image(samples[0].float())
305
+
306
+ if response_queue is not None:
307
+ response_queue.put(img)
308
+
309
+ except Exception:
310
+ print(traceback.format_exc())
311
+ response_queue.put(ModelFailure())
312
+
313
+
314
+ def none_or_str(value):
315
+ if value == "None":
316
+ return None
317
+ return value
318
+
319
+
320
+ def parse_transport_args(parser):
321
+ group = parser.add_argument_group("Transport arguments")
322
+ group.add_argument(
323
+ "--path-type",
324
+ type=str,
325
+ default="Linear",
326
+ choices=["Linear", "GVP", "VP"],
327
+ help="the type of path for transport: 'Linear', 'GVP' (Geodesic Vector Pursuit), or 'VP' (Vector Pursuit).",
328
+ )
329
+ group.add_argument(
330
+ "--prediction",
331
+ type=str,
332
+ default="velocity",
333
+ choices=["velocity", "score", "noise"],
334
+ help="the prediction model for the transport dynamics.",
335
+ )
336
+ group.add_argument(
337
+ "--loss-weight",
338
+ type=none_or_str,
339
+ default=None,
340
+ choices=[None, "velocity", "likelihood"],
341
+ help="the weighting of different components in the loss function, can be 'velocity' for dynamic modeling, 'likelihood' for statistical consistency, or None for no weighting.",
342
+ )
343
+ group.add_argument(
344
+ "--sample-eps", type=float, help="sampling in the transport model."
345
+ )
346
+ group.add_argument(
347
+ "--train-eps", type=float, help="training to stabilize the learning process."
348
+ )
349
+
350
+
351
+ def parse_ode_args(parser):
352
+ group = parser.add_argument_group("ODE arguments")
353
+ group.add_argument(
354
+ "--atol",
355
+ type=float,
356
+ default=1e-6,
357
+ help="Absolute tolerance for the ODE solver.",
358
+ )
359
+ group.add_argument(
360
+ "--rtol",
361
+ type=float,
362
+ default=1e-3,
363
+ help="Relative tolerance for the ODE solver.",
364
+ )
365
+ group.add_argument(
366
+ "--reverse", action="store_true", help="run the ODE solver in reverse."
367
+ )
368
+ group.add_argument(
369
+ "--likelihood",
370
+ action="store_true",
371
+ help="Enable calculation of likelihood during the ODE solving process.",
372
+ )
373
+
374
+
375
+ def parse_sde_args(parser):
376
+ group = parser.add_argument_group("SDE arguments")
377
+ group.add_argument(
378
+ "--sampling-method",
379
+ type=str,
380
+ default="Euler",
381
+ choices=["Euler", "Heun"],
382
+ help="the numerical method used for sampling the stochastic differential equation: 'Euler' for simplicity or 'Heun' for improved accuracy.",
383
+ )
384
+ group.add_argument(
385
+ "--diffusion-form",
386
+ type=str,
387
+ default="sigma",
388
+ choices=[
389
+ "constant",
390
+ "SBDM",
391
+ "sigma",
392
+ "linear",
393
+ "decreasing",
394
+ "increasing-decreasing",
395
+ ],
396
+ help="form of diffusion coefficient in the SDE",
397
+ )
398
+ group.add_argument(
399
+ "--diffusion-norm",
400
+ type=float,
401
+ default=1.0,
402
+ help="Normalizes the diffusion coefficient, affecting the scale of the stochastic component.",
403
+ )
404
+ group.add_argument(
405
+ "--last-step",
406
+ type=none_or_str,
407
+ default="Mean",
408
+ choices=[None, "Mean", "Tweedie", "Euler"],
409
+ help="form of last step taken in the SDE",
410
+ )
411
+ group.add_argument(
412
+ "--last-step-size", type=float, default=0.04, help="size of the last step taken"
413
+ )
414
+
415
+
416
+ def find_free_port() -> int:
417
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
418
+ sock.bind(("", 0))
419
+ port = sock.getsockname()[1]
420
+ sock.close()
421
+ return port
422
+
423
+
424
+ def main():
425
+ parser = argparse.ArgumentParser()
426
+ mode = "ODE"
427
+
428
+ parser.add_argument("--num_gpus", type=int, default=1)
429
+ parser.add_argument("--ckpt", type=str, default="./checkpoints")
430
+ parser.add_argument("--ema", type=bool, default=True)
431
+ parser.add_argument("--precision", default="bf16", choices=["bf16", "fp32"])
432
+
433
+ parse_transport_args(parser)
434
+ if mode == "ODE":
435
+ parse_ode_args(parser)
436
+ # Further processing for ODE
437
+ elif mode == "SDE":
438
+ parse_sde_args(parser)
439
+ # Further processing for SDE
440
+
441
+ args = parser.parse_known_args()[0]
442
+
443
+ if args.num_gpus != 1:
444
+ raise NotImplementedError("Multi-GPU Inference is not yet supported")
445
+
446
+ args.sampler_mode = mode
447
+
448
+ master_port = find_free_port()
449
+
450
+ processes = []
451
+ request_queues = []
452
+ response_queue = mp.Queue()
453
+ mp_barrier = mp.Barrier(args.num_gpus + 1)
454
+ for i in range(args.num_gpus):
455
+ request_queues.append(mp.Queue())
456
+ p = mp.Process(
457
+ target=model_main,
458
+ args=(
459
+ args,
460
+ master_port,
461
+ i,
462
+ request_queues[i],
463
+ response_queue if i == 0 else None,
464
+ mp_barrier,
465
+ ),
466
+ )
467
+ p.start()
468
+ processes.append(p)
469
+
470
+ with gr.Blocks() as demo:
471
+ with gr.Row():
472
+ gr.Markdown(description)
473
+ with gr.Row():
474
+ with gr.Column():
475
+ cap = gr.Textbox(
476
+ lines=2,
477
+ label="Caption",
478
+ interactive=True,
479
+ value="Miss Mexico portrait of the most beautiful mexican woman, Exquisite detail, 30-megapixel, 4k, 85-mm-lens, sharp-focus, f:8, "
480
+ "ISO 100, shutter-speed 1:125, diffuse-back-lighting, award-winning photograph, small-catchlight, High-sharpness, facial-symmetry, 8k --q 2 --ar 18:32 --v 5",
481
+ )
482
+ with gr.Row():
483
+ res_choices = ["1024x1024", "512x2048", "2048x512"] + [
484
+ "(Extrapolation) 1664x1664",
485
+ "(Extrapolation) 1024x2048",
486
+ "(Extrapolation) 2048x1024",
487
+ ]
488
+ resolution = gr.Dropdown(
489
+ value=res_choices[0], choices=res_choices, label="Resolution"
490
+ )
491
+ with gr.Row():
492
+ num_sampling_steps = gr.Slider(
493
+ minimum=1,
494
+ maximum=70,
495
+ value=30,
496
+ interactive=True,
497
+ label="Sampling steps",
498
+ )
499
+ seed = gr.Slider(
500
+ minimum=0,
501
+ maximum=int(1e5),
502
+ value=1,
503
+ step=1,
504
+ interactive=True,
505
+ label="Seed (0 for random)",
506
+ )
507
+ with gr.Accordion(
508
+ "Advanced Settings for Resolution Extrapolation", open=False
509
+ ):
510
+ with gr.Row():
511
+ solver = gr.Dropdown(
512
+ value="euler",
513
+ choices=["euler", "dopri5", "dopri8"],
514
+ label="solver",
515
+ )
516
+ t_shift = gr.Slider(
517
+ minimum=1,
518
+ maximum=20,
519
+ value=6,
520
+ step=1,
521
+ interactive=True,
522
+ label="Time shift",
523
+ )
524
+ cfg_scale = gr.Slider(
525
+ minimum=1.0,
526
+ maximum=20.0,
527
+ value=4.0,
528
+ interactive=True,
529
+ label="CFG scale",
530
+ )
531
+ with gr.Row():
532
+ ntk_scaling = gr.Checkbox(
533
+ value=True,
534
+ interactive=True,
535
+ label="ntk scaling",
536
+ )
537
+ proportional_attn = gr.Checkbox(
538
+ value=True,
539
+ interactive=True,
540
+ label="Proportional attention",
541
+ )
542
+ with gr.Row():
543
+ submit_btn = gr.Button("Submit", variant="primary")
544
+ # reset_btn = gr.ClearButton([
545
+ # cap, resolution,
546
+ # num_sampling_steps, cfg_scale, solver,
547
+ # t_shift, seed,
548
+ # ntk_scaling, proportional_attn
549
+ # ])
550
+ with gr.Column():
551
+ default_img = Image.open("./image.png")
552
+ output_img = gr.Image(
553
+ label="Generated image",
554
+ interactive=False,
555
+ format="png",
556
+ value=default_img,
557
+ )
558
+
559
+ with gr.Row():
560
+ gr.Examples(
561
+ examples,
562
+ [cap],
563
+ label="Examples",
564
+ )
565
+
566
+ def on_submit(*args):
567
+ for q in request_queues:
568
+ q.put(args)
569
+ result = response_queue.get()
570
+ if isinstance(result, ModelFailure):
571
+ raise RuntimeError("model inference failed; see the server log for the traceback")
572
+ return result
573
+
574
+ submit_btn.click(
575
+ on_submit,
576
+ [
577
+ cap,
578
+ resolution,
579
+ num_sampling_steps,
580
+ cfg_scale,
581
+ solver,
582
+ t_shift,
583
+ seed,
584
+ ntk_scaling,
585
+ proportional_attn,
586
+ ],
587
+ [output_img],
588
+ )
589
+
590
+ mp_barrier.wait()
591
+ demo.queue().launch(share=True, server_name="0.0.0.0")
592
+
593
+
594
+ if __name__ == "__main__":
595
+ os.system("mkdir -p ./checkpoints")
596
+ os.system("huggingface-cli download --resume-download Alpha-VLLM/Lumina-Next-T2I --local-dir ./checkpoints --local-dir-use-symlinks False")
597
+ mp.set_start_method("spawn")
598
+ main()
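For context on the serving pattern in app.py above (one request queue per GPU worker, a single response queue owned by rank 0, and a barrier that gates the Gradio launch), here is a minimal, GPU-free sketch of the same handshake. The worker function and the payload string are illustrative stand-ins, not part of the repository.

import multiprocessing as mp

def toy_worker(rank, request_queue, response_queue, barrier):
    barrier.wait()                      # mirrors mp_barrier.wait() before serving
    prompt = request_queue.get()        # same blocking get() as model_main
    if response_queue is not None:      # only rank 0 owns the response queue
        response_queue.put(f"[rank {rank}] got: {prompt}")

if __name__ == "__main__":
    mp.set_start_method("spawn")
    num_workers = 2
    response_queue = mp.Queue()
    barrier = mp.Barrier(num_workers + 1)
    request_queues = [mp.Queue() for _ in range(num_workers)]
    procs = [
        mp.Process(target=toy_worker,
                   args=(i, request_queues[i], response_queue if i == 0 else None, barrier))
        for i in range(num_workers)
    ]
    for p in procs:
        p.start()
    barrier.wait()                      # wait until every worker is ready
    for q in request_queues:            # broadcast one request, like on_submit
        q.put("a cat in ukiyo-e style")
    print(response_queue.get())         # single reply from rank 0
    for p in procs:
        p.join()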
models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # from .model import DiT_Llama_5B_patch2
2
+ from .model import DiT_Llama_2B_patch2
models/components.py ADDED
@@ -0,0 +1,54 @@
1
+ import warnings
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ try:
7
+ from apex.normalization import FusedRMSNorm as RMSNorm
8
+ except ImportError:
9
+ warnings.warn("Cannot import apex RMSNorm; falling back to the vanilla implementation")
10
+
11
+ class RMSNorm(torch.nn.Module):
12
+ def __init__(self, dim: int, eps: float = 1e-6):
13
+ """
14
+ Initialize the RMSNorm normalization layer.
15
+
16
+ Args:
17
+ dim (int): The dimension of the input tensor.
18
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19
+
20
+ Attributes:
21
+ eps (float): A small value added to the denominator for numerical stability.
22
+ weight (nn.Parameter): Learnable scaling parameter.
23
+
24
+ """
25
+ super().__init__()
26
+ self.eps = eps
27
+ self.weight = nn.Parameter(torch.ones(dim))
28
+
29
+ def _norm(self, x):
30
+ """
31
+ Apply the RMSNorm normalization to the input tensor.
32
+
33
+ Args:
34
+ x (torch.Tensor): The input tensor.
35
+
36
+ Returns:
37
+ torch.Tensor: The normalized tensor.
38
+
39
+ """
40
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41
+
42
+ def forward(self, x):
43
+ """
44
+ Forward pass through the RMSNorm layer.
45
+
46
+ Args:
47
+ x (torch.Tensor): The input tensor.
48
+
49
+ Returns:
50
+ torch.Tensor: The output tensor after applying RMSNorm.
51
+
52
+ """
53
+ output = self._norm(x.float()).type_as(x)
54
+ return output * self.weight
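A minimal usage sketch of the RMSNorm defined above (the apex FusedRMSNorm path behaves the same numerically, just fused); the tensor shapes are arbitrary examples, with 2304 matching the 2B model width.

import torch
from models.components import RMSNorm

norm = RMSNorm(2304, eps=1e-6)      # works with either the apex or the fallback class
x = torch.randn(2, 16, 2304)        # (batch, tokens, dim); sizes are illustrative
y = norm(x)
# each token vector is divided by its root-mean-square, then scaled by the
# learnable per-channel weight (initialized to ones), so initially RMS(y) ~= 1
print(y.shape, y.pow(2).mean(-1).sqrt().mean().item())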
models/model.py ADDED
@@ -0,0 +1,908 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import functools
13
+ import logging
14
+ import math
15
+ from typing import Optional, Tuple, List
16
+
17
+ # from apex.normalization import FusedRMSNorm as RMSNorm
18
+ from .components import RMSNorm
19
+ import fairscale.nn.model_parallel.initialize as fs_init
20
+ from fairscale.nn.model_parallel.layers import (
21
+ ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
22
+ )
23
+ from flash_attn import flash_attn_varlen_func
24
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
25
+ import torch
26
+ import torch.distributed as dist
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ def modulate(x, shift, scale):
34
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
35
+
36
+
37
+ #############################################################################
38
+ # Embedding Layers for Timesteps and Class Labels #
39
+ #############################################################################
40
+
41
+ class ParallelTimestepEmbedder(nn.Module):
42
+ """
43
+ Embeds scalar timesteps into vector representations.
44
+ """
45
+ def __init__(self, hidden_size, frequency_embedding_size=256):
46
+ super().__init__()
47
+ self.mlp = nn.Sequential(
48
+ ColumnParallelLinear(
49
+ frequency_embedding_size, hidden_size, bias=True,
50
+ gather_output=False,
51
+ init_method=functools.partial(nn.init.normal_, std=0.02),
52
+ ),
53
+ nn.SiLU(),
54
+ RowParallelLinear(
55
+ hidden_size, hidden_size, bias=True, input_is_parallel=True,
56
+ init_method=functools.partial(nn.init.normal_, std=0.02),
57
+ ),
58
+ )
59
+ self.frequency_embedding_size = frequency_embedding_size
60
+
61
+ @staticmethod
62
+ def timestep_embedding(t, dim, max_period=10000):
63
+ """
64
+ Create sinusoidal timestep embeddings.
65
+ :param t: a 1-D Tensor of N indices, one per batch element.
66
+ These may be fractional.
67
+ :param dim: the dimension of the output.
68
+ :param max_period: controls the minimum frequency of the embeddings.
69
+ :return: an (N, D) Tensor of positional embeddings.
70
+ """
71
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
72
+ half = dim // 2
73
+ freqs = torch.exp(
74
+ -math.log(max_period) * torch.arange(
75
+ start=0, end=half, dtype=torch.float32
76
+ ) / half
77
+ ).to(device=t.device)
78
+ args = t[:, None].float() * freqs[None]
79
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
80
+ if dim % 2:
81
+ embedding = torch.cat([
82
+ embedding, torch.zeros_like(embedding[:, :1])
83
+ ], dim=-1)
84
+ return embedding
85
+
86
+ def forward(self, t):
87
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
88
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
89
+ return t_emb
90
+
91
+
92
+ class ParallelLabelEmbedder(nn.Module):
93
+ r"""Embeds class labels into vector representations. Also handles label
94
+ dropout for classifier-free guidance.
95
+ """
96
+ def __init__(self, num_classes, hidden_size, dropout_prob):
97
+ super().__init__()
98
+ use_cfg_embedding = int(dropout_prob > 0)
99
+ self.embedding_table = ParallelEmbedding(
100
+ num_classes + use_cfg_embedding, hidden_size,
101
+ init_method=functools.partial(nn.init.normal_, std=0.02),
102
+ )
103
+ self.num_classes = num_classes
104
+ self.dropout_prob = dropout_prob
105
+
106
+ def token_drop(self, labels, force_drop_ids=None):
107
+ """
108
+ Drops labels to enable classifier-free guidance.
109
+ """
110
+ if force_drop_ids is None:
111
+ drop_ids = torch.rand(
112
+ labels.shape[0], device=labels.device
113
+ ) < self.dropout_prob
114
+ drop_ids = drop_ids.cuda()
115
+ dist.broadcast(
116
+ drop_ids,
117
+ fs_init.get_model_parallel_src_rank(),
118
+ fs_init.get_model_parallel_group(),
119
+ )
120
+ drop_ids = drop_ids.to(labels.device)
121
+ else:
122
+ drop_ids = force_drop_ids == 1
123
+ labels = torch.where(drop_ids, self.num_classes, labels)
124
+ return labels
125
+
126
+ def forward(self, labels, train, force_drop_ids=None):
127
+ use_dropout = self.dropout_prob > 0
128
+ if (train and use_dropout) or (force_drop_ids is not None):
129
+ labels = self.token_drop(labels, force_drop_ids)
130
+ embeddings = self.embedding_table(labels)
131
+ return embeddings
132
+
133
+
134
+ #############################################################################
135
+ # Core DiT Model #
136
+ #############################################################################
137
+
138
+
139
+ class Attention(nn.Module):
140
+ """Multi-head attention module."""
141
+ def __init__(self, dim: int, n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, y_dim: int):
142
+ """
143
+ Initialize the Attention module.
144
+
145
+ Args:
146
+ dim (int): Number of input dimensions.
147
+ n_heads (int): Number of heads.
148
+ n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
149
+
150
+ """
151
+ super().__init__()
152
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
153
+ model_parallel_size = fs_init.get_model_parallel_world_size()
154
+ self.n_local_heads = n_heads // model_parallel_size
155
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
156
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
157
+ self.head_dim = dim // n_heads
158
+
159
+ self.wq = ColumnParallelLinear(
160
+ dim, n_heads * self.head_dim, bias=False, gather_output=False,
161
+ init_method=nn.init.xavier_uniform_,
162
+ )
163
+ self.wk = ColumnParallelLinear(
164
+ dim, self.n_kv_heads * self.head_dim, bias=False,
165
+ gather_output=False, init_method=nn.init.xavier_uniform_,
166
+ )
167
+ self.wv = ColumnParallelLinear(
168
+ dim, self.n_kv_heads * self.head_dim, bias=False,
169
+ gather_output=False, init_method=nn.init.xavier_uniform_,
170
+ )
171
+ if y_dim > 0:
172
+ self.wk_y = ColumnParallelLinear(
173
+ y_dim, self.n_kv_heads * self.head_dim, bias=False,
174
+ gather_output=False, init_method=nn.init.xavier_uniform_,
175
+ )
176
+ self.wv_y = ColumnParallelLinear(
177
+ y_dim, self.n_kv_heads * self.head_dim, bias=False,
178
+ gather_output=False, init_method=nn.init.xavier_uniform_,
179
+ )
180
+ self.gate = nn.Parameter(torch.zeros([self.n_local_heads]))
181
+
182
+ self.wo = RowParallelLinear(
183
+ n_heads * self.head_dim, dim, bias=False,
184
+ input_is_parallel=True, init_method=nn.init.xavier_uniform_,
185
+ )
186
+
187
+ if qk_norm:
188
+ self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim)
189
+ self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
190
+ if y_dim > 0:
191
+ self.ky_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
192
+ else:
193
+ self.ky_norm = nn.Identity()
194
+ else:
195
+ self.q_norm = self.k_norm = nn.Identity()
196
+ self.ky_norm = nn.Identity()
197
+
198
+ # for proportional attention computation
199
+ self.base_seqlen = None
200
+ self.proportional_attn = False
201
+
202
+ @staticmethod
203
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
204
+ """
205
+ Reshape frequency tensor for broadcasting it with another tensor.
206
+
207
+ This function reshapes the frequency tensor to have the same shape as
208
+ the target tensor 'x' for the purpose of broadcasting the frequency
209
+ tensor during element-wise operations.
210
+
211
+ Args:
212
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
213
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
214
+
215
+ Returns:
216
+ torch.Tensor: Reshaped frequency tensor.
217
+
218
+ Raises:
219
+ AssertionError: If the frequency tensor doesn't match the expected
220
+ shape.
221
+ AssertionError: If the target tensor 'x' doesn't have the expected
222
+ number of dimensions.
223
+ """
224
+ ndim = x.ndim
225
+ assert 0 <= 1 < ndim
226
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
227
+ shape = [d if i == 1 or i == ndim - 1 else 1
228
+ for i, d in enumerate(x.shape)]
229
+ return freqs_cis.view(*shape)
230
+
231
+ @staticmethod
232
+ def apply_rotary_emb(
233
+ x_in: torch.Tensor,
234
+ freqs_cis: torch.Tensor,
235
+ ) -> torch.Tensor:
236
+ """
237
+ Apply rotary embeddings to input tensors using the given frequency
238
+ tensor.
239
+
240
+ This function applies rotary embeddings to the input tensor 'x_in' (a
241
+ query or key tensor) using the provided frequency tensor 'freqs_cis'. The
242
+ input tensors are reshaped as complex numbers, and the frequency tensor
243
+ is reshaped for broadcasting compatibility. The resulting tensors
244
+ contain rotary embeddings and are returned as real tensors.
245
+
246
+ Args:
247
+ x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
248
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
249
+ exponentials.
250
+
251
+ Returns:
252
+ torch.Tensor: The modified input tensor with rotary embeddings
253
+ applied.
254
+ """
255
+ with torch.cuda.amp.autocast(enabled=False):
256
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
257
+ freqs_cis = freqs_cis.unsqueeze(2)
258
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
259
+ return x_out.type_as(x_in)
260
+
261
+ # copied from huggingface modeling_llama.py
262
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
263
+
264
+ def _get_unpad_data(attention_mask):
265
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
266
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
267
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
268
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
269
+ return (
270
+ indices,
271
+ cu_seqlens,
272
+ max_seqlen_in_batch,
273
+ )
274
+
275
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
276
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
277
+
278
+ key_layer = index_first_axis(
279
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
280
+ )
281
+ value_layer = index_first_axis(
282
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
283
+ )
284
+ if query_length == kv_seq_len:
285
+ query_layer = index_first_axis(
286
+ query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), indices_k
287
+ )
288
+ cu_seqlens_q = cu_seqlens_k
289
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
290
+ indices_q = indices_k
291
+ elif query_length == 1:
292
+ max_seqlen_in_batch_q = 1
293
+ cu_seqlens_q = torch.arange(
294
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
295
+ ) # There is a memcpy here, that is very bad.
296
+ indices_q = cu_seqlens_q[:-1]
297
+ query_layer = query_layer.squeeze(1)
298
+ else:
299
+ # The -q_len: slice assumes left padding.
300
+ attention_mask = attention_mask[:, -query_length:]
301
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
302
+
303
+ return (
304
+ query_layer,
305
+ key_layer,
306
+ value_layer,
307
+ indices_q,
308
+ (cu_seqlens_q, cu_seqlens_k),
309
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
310
+ )
311
+
312
+ def forward(
313
+ self,
314
+ x: torch.Tensor,
315
+ x_mask: torch.Tensor,
316
+ freqs_cis: torch.Tensor,
317
+ y: torch.Tensor,
318
+ y_mask: torch.Tensor,
319
+ ) -> torch.Tensor:
320
+ """
321
+
322
+ Args:
323
+ x:
324
+ x_mask:
325
+ freqs_cis:
326
+ y:
327
+ y_mask:
328
+
329
+ Returns:
330
+
331
+ """
332
+ bsz, seqlen, _ = x.shape
333
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
334
+ dtype = xq.dtype
335
+
336
+ xq = self.q_norm(xq)
337
+ xk = self.k_norm(xk)
338
+
339
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
340
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
341
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
342
+
343
+ xq = Attention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
344
+ xk = Attention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
345
+
346
+ xq, xk = xq.to(dtype), xk.to(dtype)
347
+
348
+ if dtype in [torch.float16, torch.bfloat16]:
349
+ # begin var_len flash attn
350
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
351
+ xq, xk, xv, x_mask, seqlen
352
+ )
353
+
354
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
355
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
356
+
357
+ if self.proportional_attn:
358
+ softmax_scale = math.sqrt(math.log(seqlen, self.base_seqlen) / self.head_dim)
359
+ else:
360
+ softmax_scale = math.sqrt(1 / self.head_dim)
361
+
362
+ attn_output_unpad = flash_attn_varlen_func(
363
+ query_states,
364
+ key_states,
365
+ value_states,
366
+ cu_seqlens_q=cu_seqlens_q,
367
+ cu_seqlens_k=cu_seqlens_k,
368
+ max_seqlen_q=max_seqlen_in_batch_q,
369
+ max_seqlen_k=max_seqlen_in_batch_k,
370
+ dropout_p=0.,
371
+ causal=False,
372
+ softmax_scale=softmax_scale
373
+ )
374
+ output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
375
+ # end var_len_flash_attn
376
+
377
+ else:
378
+ output = F.scaled_dot_product_attention(
379
+ xq.permute(0, 2, 1, 3),
380
+ xk.permute(0, 2, 1, 3),
381
+ xv.permute(0, 2, 1, 3),
382
+ attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
383
+ ).permute(0, 2, 1, 3).to(dtype)
384
+
385
+ if hasattr(self, "wk_y"):
386
+ # todo better flash_attn support
387
+ yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
388
+ yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
389
+ n_rep = self.n_local_heads // self.n_local_kv_heads
390
+ if n_rep >= 1:
391
+ yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
392
+ yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
393
+ output_y = F.scaled_dot_product_attention(
394
+ xq.permute(0, 2, 1, 3),
395
+ yk.permute(0, 2, 1, 3),
396
+ yv.permute(0, 2, 1, 3),
397
+ y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1)
398
+ ).permute(0, 2, 1, 3)
399
+ output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
400
+ output = output + output_y
401
+
402
+ output = output.flatten(-2)
403
+
404
+ return self.wo(output)
405
+
406
+
407
+ class FeedForward(nn.Module):
408
+ def __init__(
409
+ self,
410
+ dim: int,
411
+ hidden_dim: int,
412
+ multiple_of: int,
413
+ ffn_dim_multiplier: Optional[float],
414
+ ):
415
+ """
416
+ Initialize the FeedForward module.
417
+
418
+ Args:
419
+ dim (int): Input dimension.
420
+ hidden_dim (int): Hidden dimension of the feedforward layer.
421
+ multiple_of (int): Value to ensure hidden dimension is a multiple
422
+ of this value.
423
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden
424
+ dimension. Defaults to None.
425
+
426
+ Attributes:
427
+ w1 (ColumnParallelLinear): Linear transformation for the first
428
+ layer.
429
+ w2 (RowParallelLinear): Linear transformation for the second layer.
430
+ w3 (ColumnParallelLinear): Linear transformation for the third
431
+ layer.
432
+
433
+ """
434
+ super().__init__()
435
+ hidden_dim = int(2 * hidden_dim / 3)
436
+ # custom dim factor multiplier
437
+ if ffn_dim_multiplier is not None:
438
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
439
+ hidden_dim = multiple_of * (
440
+ (hidden_dim + multiple_of - 1) // multiple_of
441
+ )
442
+
443
+ self.w1 = ColumnParallelLinear(
444
+ dim, hidden_dim, bias=False, gather_output=False,
445
+ init_method=nn.init.xavier_uniform_,
446
+ )
447
+ self.w2 = RowParallelLinear(
448
+ hidden_dim, dim, bias=False, input_is_parallel=True,
449
+ init_method=nn.init.xavier_uniform_,
450
+ )
451
+ self.w3 = ColumnParallelLinear(
452
+ dim, hidden_dim, bias=False, gather_output=False,
453
+ init_method=nn.init.xavier_uniform_,
454
+ )
455
+
456
+ # @torch.compile
457
+ def _forward_silu_gating(self, x1, x3):
458
+ return F.silu(x1) * x3
459
+
460
+ def forward(self, x):
461
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
462
+
463
+
464
+ class TransformerBlock(nn.Module):
465
+ def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int,
466
+ multiple_of: int, ffn_dim_multiplier: float, norm_eps: float,
467
+ qk_norm: bool, y_dim: int) -> None:
468
+ """
469
+ Initialize a TransformerBlock.
470
+
471
+ Args:
472
+ layer_id (int): Identifier for the layer.
473
+ dim (int): Embedding dimension of the input features.
474
+ n_heads (int): Number of attention heads.
475
+ n_kv_heads (Optional[int]): Number of attention heads in key and
476
+ value features (if using GQA), or set to None for the same as
477
+ query.
478
+ multiple_of (int):
479
+ ffn_dim_multiplier (float):
480
+ norm_eps (float):
481
+
482
+ Attributes:
483
+ n_heads (int): Number of attention heads.
484
+ dim (int): Dimension size of the model.
485
+ head_dim (int): Dimension size of each attention head.
486
+ attention (Attention): Attention module.
487
+ feed_forward (FeedForward): FeedForward module.
488
+ layer_id (int): Identifier for the layer.
489
+ attention_norm (RMSNorm): Layer normalization for attention output.
490
+ ffn_norm (RMSNorm): Layer normalization for feedforward output.
491
+
492
+ """
493
+ super().__init__()
494
+ self.dim = dim
495
+ self.head_dim = dim // n_heads
496
+ self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim)
497
+ self.feed_forward = FeedForward(
498
+ dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of,
499
+ ffn_dim_multiplier=ffn_dim_multiplier,
500
+ )
501
+ self.layer_id = layer_id
502
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
503
+ self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
504
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
505
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
506
+
507
+ self.adaLN_modulation = nn.Sequential(
508
+ nn.SiLU(),
509
+ ColumnParallelLinear(
510
+ min(dim, 1024), 6 * dim, bias=True, gather_output=True,
511
+ init_method=nn.init.zeros_,
512
+ ),
513
+ )
514
+
515
+ self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)
516
+
517
+ def forward(
518
+ self,
519
+ x: torch.Tensor,
520
+ x_mask: torch.Tensor,
521
+ freqs_cis: torch.Tensor,
522
+ y: torch.Tensor,
523
+ y_mask: torch.Tensor,
524
+ adaln_input: Optional[torch.Tensor] = None,
525
+ ):
526
+ """
527
+ Perform a forward pass through the TransformerBlock.
528
+
529
+ Args:
530
+ x (torch.Tensor): Input tensor.
531
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
532
+
533
+ Returns:
534
+ torch.Tensor: Output tensor after applying attention and
535
+ feedforward layers.
536
+
537
+ """
538
+ if adaln_input is not None:
539
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
540
+ self.adaLN_modulation(adaln_input).chunk(6, dim=1)
541
+
542
+ x = x + self.attention_norm1(gate_msa.unsqueeze(1) * self.attention(
543
+ modulate(self.attention_norm(x), shift_msa, scale_msa),
544
+ x_mask,
545
+ freqs_cis,
546
+ self.attention_y_norm(y),
547
+ y_mask,
548
+ ))
549
+ d = x.shape[-1]
550
+ x = x + self.ffn_norm1(gate_mlp.unsqueeze(1) * self.feed_forward(
551
+ modulate(self.ffn_norm(x), shift_mlp, scale_mlp).view(-1, d),
552
+ ).view(*x.shape))
553
+
554
+ else:
555
+ x = x + self.attention_norm1(self.attention(
556
+ self.attention_norm(x), x_mask, freqs_cis, self.attention_y_norm(y), y_mask
557
+ ))
558
+ # for compatibility with torch.compile because the sequence length changes
559
+ B, L, D = x.shape
560
+ x = x.view(B*L, D)
561
+ x = x + self.ffn_norm1(self.feed_forward(self.ffn_norm(x)))
562
+ x = x.view(B, L, D)
563
+
564
+ return x
565
+
566
+
567
+ class ParallelFinalLayer(nn.Module):
568
+ """
569
+ The final layer of DiT.
570
+ """
571
+ def __init__(self, hidden_size, patch_size, out_channels):
572
+ super().__init__()
573
+ self.norm_final = nn.LayerNorm(
574
+ hidden_size, elementwise_affine=False, eps=1e-6,
575
+ )
576
+ self.linear = ColumnParallelLinear(
577
+ hidden_size, patch_size * patch_size * out_channels, bias=True,
578
+ init_method=nn.init.zeros_, gather_output=True,
579
+ )
580
+ self.adaLN_modulation = nn.Sequential(
581
+ nn.SiLU(),
582
+ ColumnParallelLinear(
583
+ min(hidden_size, 1024), 2 * hidden_size, bias=True,
584
+ init_method=nn.init.zeros_, gather_output=True,
585
+ ),
586
+ )
587
+
588
+ def forward(self, x, c):
589
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
590
+ x = modulate(self.norm_final(x), shift, scale)
591
+ x = self.linear(x)
592
+ return x
593
+
594
+
595
+ class DiT_Llama(nn.Module):
596
+ """
597
+ Diffusion model with a Transformer backbone.
598
+ """
599
+ def __init__(
600
+ self,
601
+ patch_size: int = 2,
602
+ in_channels: int = 4,
603
+ dim: int = 4096,
604
+ n_layers: int = 32,
605
+ n_heads: int = 32,
606
+ n_kv_heads: Optional[int] = None,
607
+ multiple_of: int = 256,
608
+ ffn_dim_multiplier: Optional[float] = None,
609
+ norm_eps: float = 1e-5,
610
+ learn_sigma: bool = True,
611
+ qk_norm: bool = False,
612
+ cap_feat_dim: int = 5120,
613
+ rope_scaling_factor: float = 1.,
614
+ ntk_factor: float=1.
615
+ ) -> None:
616
+ super().__init__()
617
+ self.learn_sigma = learn_sigma
618
+ self.in_channels = in_channels
619
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
620
+ self.patch_size = patch_size
621
+
622
+ self.x_embedder = ColumnParallelLinear(
623
+ in_features=patch_size * patch_size * in_channels,
624
+ out_features=dim,
625
+ bias=True,
626
+ gather_output=True,
627
+ init_method=nn.init.xavier_uniform_,
628
+ )
629
+ nn.init.constant_(self.x_embedder.bias, 0.)
630
+
631
+ self.t_embedder = ParallelTimestepEmbedder(min(dim, 1024))
632
+ self.cap_embedder = nn.Sequential(
633
+ nn.LayerNorm(cap_feat_dim),
634
+ ColumnParallelLinear(cap_feat_dim, min(dim, 1024), bias=True, gather_output=True,
635
+ init_method=nn.init.zeros_),
636
+ )
637
+
638
+ self.layers = nn.ModuleList([
639
+ TransformerBlock(layer_id, dim, n_heads, n_kv_heads, multiple_of,
640
+ ffn_dim_multiplier, norm_eps, qk_norm, cap_feat_dim)
641
+ for layer_id in range(n_layers)
642
+ ])
643
+ self.final_layer = ParallelFinalLayer(dim, patch_size, self.out_channels)
644
+
645
+ assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
646
+ self.dim = dim
647
+ self.n_heads = n_heads
648
+ self.freqs_cis = DiT_Llama.precompute_freqs_cis(
649
+ dim // n_heads, 384, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
650
+ )
651
+ self.rope_scaling_factor = rope_scaling_factor
652
+ self.ntk_factor = ntk_factor
653
+ # self.eol_token = nn.Parameter(torch.empty(dim))
654
+ self.pad_token = nn.Parameter(torch.empty(dim))
655
+ # nn.init.normal_(self.eol_token, std=0.02)
656
+ nn.init.normal_(self.pad_token, std=0.02)
657
+
658
+ def unpatchify(self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False) -> List[torch.Tensor]:
659
+ """
660
+ x: (N, T, patch_size**2 * C)
661
+ imgs: (N, H, W, C)
662
+ """
663
+ pH = pW = self.patch_size
664
+ if return_tensor:
665
+ H, W = img_size[0]
666
+ B = x.size(0)
667
+ L = (H // pH) * (W // pW)
668
+ x = x[:, :L].view(B, H // pH, W // pW, pH, pW, self.out_channels)
669
+ x = x.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
670
+ return x
671
+ else:
672
+ imgs = []
673
+ for i in range(x.size(0)):
674
+ H, W = img_size[i]
675
+ L = (H // pH) * (W // pW)
676
+ imgs.append(x[i][:L].view(
677
+ H // pH, W // pW, pH, pW, self.out_channels
678
+ ).permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2))
679
+ return imgs
680
+
681
+ def patchify_and_embed(
682
+ self,
683
+ x: List[torch.Tensor] | torch.Tensor
684
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]:
685
+ self.freqs_cis = self.freqs_cis.to(x[0].device)
686
+ if isinstance(x, torch.Tensor):
687
+ pH = pW = self.patch_size
688
+ B, C, H, W = x.size()
689
+ x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3)
690
+ x = self.x_embedder(x)
691
+ x = x.flatten(1, 2)
692
+
693
+ mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
694
+ # leave the first line for text
695
+ return x, mask, [(H, W)] * B, self.freqs_cis[:H//pH, :W//pW].flatten(0,1).unsqueeze(0)
696
+ else:
697
+ pH = pW = self.patch_size
698
+ x_embed = []
699
+ freqs_cis = []
700
+ img_size = []
701
+ l_effective_seq_len = []
702
+
703
+ for img in x:
704
+ C, H, W = img.size()
705
+ item_freqs_cis = self.freqs_cis[:H//pH, :W//pW]
706
+ freqs_cis.append(item_freqs_cis.flatten(0,1))
707
+ img_size.append((H, W))
708
+ img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 0, 2, 4).flatten(2)
709
+ img = self.x_embedder(img)
710
+ img = img.flatten(0, 1)
711
+ l_effective_seq_len.append(len(img))
712
+ x_embed.append(img)
713
+
714
+ max_seq_len = max(l_effective_seq_len)
715
+ mask = torch.zeros(len(x), max_seq_len, dtype=torch.int32, device=x[0].device)
716
+ padded_x_embed = []
717
+ padded_freqs_cis = []
718
+ for i, (item_embed, item_freqs_cis, item_seq_len) in enumerate(zip(
719
+ x_embed, freqs_cis, l_effective_seq_len
720
+ )):
721
+ item_embed = torch.cat([
722
+ item_embed,
723
+ self.pad_token.view(1, -1).expand(max_seq_len - item_seq_len, -1),
724
+ ], dim=0)
725
+ item_freqs_cis = torch.cat([
726
+ item_freqs_cis,
727
+ item_freqs_cis[-1:].expand(max_seq_len - item_seq_len, -1)
728
+ ], dim=0)
729
+ padded_x_embed.append(item_embed)
730
+ padded_freqs_cis.append(item_freqs_cis)
731
+ mask[i][:item_seq_len] = 1
732
+
733
+ x_embed = torch.stack(padded_x_embed, dim=0)
734
+ freqs_cis = torch.stack(padded_freqs_cis, dim=0)
735
+ return x_embed, mask, img_size, freqs_cis
736
+
737
+ def forward(self, x, t, cap_feats, cap_mask):
738
+ """
739
+ Forward pass of DiT.
740
+ t: (N,) tensor of diffusion timesteps
741
+ cap_feats, cap_mask: caption features and the corresponding attention mask
742
+ """
743
+ x_is_tensor = isinstance(x, torch.Tensor)
744
+ x, mask, img_size, freqs_cis = self.patchify_and_embed(x)
745
+ freqs_cis = freqs_cis.to(x.device)
746
+
747
+ # cap_freqs_cis = self.freqs_cis[:1, :cap_feats.shape[1]].to(x.device)
748
+
749
+ t = self.t_embedder(t) # (N, D)
750
+ cap_mask_float = cap_mask.float().unsqueeze(-1)
751
+ cap_feats_pool = (cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(dim=1)
752
+ cap_feats_pool = cap_feats_pool.to(cap_feats)
753
+ cap_emb = self.cap_embedder(cap_feats_pool)
754
+ adaln_input = t + cap_emb
755
+
756
+ cap_mask = cap_mask.bool()
757
+ for layer in self.layers:
758
+ x = layer(
759
+ x, mask, freqs_cis, cap_feats, cap_mask,
760
+ adaln_input=adaln_input
761
+ )
762
+
763
+ x = self.final_layer(x, adaln_input)
764
+ x = self.unpatchify(x, img_size, return_tensor=x_is_tensor)
765
+ if self.learn_sigma:
766
+ if x_is_tensor:
767
+ x, _ = x.chunk(2, dim=1)
768
+ else:
769
+ x = [_.chunk(2, dim=0)[0] for _ in x]
770
+ return x
771
+
772
+ def forward_with_cfg(self, x, t, cap_feats, cap_mask, cfg_scale, rope_scaling_factor=None, ntk_factor=None, base_seqlen: Optional[int] = None, proportional_attn: bool = False):
773
+ # """
774
+ # Forward pass of DiT, but also batches the unconditional forward pass
775
+ # for classifier-free guidance.
776
+ # """
777
+ # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
778
+ # print(ntk_factor, rope_scaling_factor, self.ntk_factor, self.rope_scaling_factor)
779
+ if rope_scaling_factor is not None or ntk_factor is not None:
780
+ rope_scaling_factor = rope_scaling_factor if rope_scaling_factor is not None else self.rope_scaling_factor
781
+ ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
782
+ if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
783
+ print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
784
+ self.freqs_cis = DiT_Llama.precompute_freqs_cis(
785
+ self.dim // self.n_heads, 384,
786
+ rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
787
+ )
788
+ self.rope_scaling_factor = rope_scaling_factor
789
+ self.ntk_factor = ntk_factor
790
+
791
+ if proportional_attn:
792
+ assert base_seqlen is not None
793
+ for layer in self.layers:
794
+ layer.attention.base_seqlen = base_seqlen
795
+ layer.attention.proportional_attn = proportional_attn
796
+ else:
797
+ for layer in self.layers:
798
+ layer.attention.base_seqlen = None
799
+ layer.attention.proportional_attn = proportional_attn
800
+
801
+ half = x[: len(x) // 2]
802
+ combined = torch.cat([half, half], dim=0)
803
+ model_out = self.forward(combined, t, cap_feats, cap_mask)
804
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
805
+ # three channels by default. The standard approach to cfg applies it to all channels.
806
+ # This can be done by uncommenting the following line and commenting-out the line following that.
807
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
808
+ eps, rest = model_out[:, :3], model_out[:, 3:]
809
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
810
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
811
+ eps = torch.cat([half_eps, half_eps], dim=0)
812
+ return torch.cat([eps, rest], dim=1)
813
+
814
+ @staticmethod
815
+ def precompute_freqs_cis(
816
+ dim: int,
817
+ end: int,
818
+ theta: float = 10000.0,
819
+ rope_scaling_factor: float = 1.0,
820
+ ntk_factor: float = 1.0
821
+ ):
822
+ """
823
+ Precompute the frequency tensor for complex exponentials (cis) with
824
+ given dimensions.
825
+
826
+ This function calculates a frequency tensor with complex exponentials
827
+ using the given dimension 'dim' and the end index 'end'. The 'theta'
828
+ parameter scales the frequencies. The returned tensor contains complex
829
+ values in complex64 data type.
830
+
831
+ Args:
832
+ dim (int): Dimension of the frequency tensor.
833
+ end (int): End index for precomputing frequencies.
834
+ theta (float, optional): Scaling factor for frequency computation.
835
+ Defaults to 10000.0.
836
+
837
+ Returns:
838
+ torch.Tensor: Precomputed frequency tensor with complex
839
+ exponentials.
840
+ """
841
+
842
+ theta = theta * ntk_factor
843
+
844
+ logger.info(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
845
+ freqs = 1.0 / (theta ** (
846
+ torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim
847
+ ))
848
+ t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
849
+ t = t / rope_scaling_factor
850
+ freqs = torch.outer(t, freqs).float() # type: ignore
851
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
852
+
853
+ freqs_cis_h = freqs_cis.view(end, 1, dim//4, 1).repeat(1, end, 1, 1)
854
+ freqs_cis_w = freqs_cis.view(1, end, dim//4, 1).repeat(end, 1, 1, 1)
855
+ freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
856
+ return freqs_cis
857
+
858
+ def parameter_count(self) -> int:
859
+ tensor_parallel_module_list = (
860
+ ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
861
+ )
862
+ total_params = 0
863
+
864
+ def _recursive_count_params(module):
865
+ nonlocal total_params
866
+ is_tp_module = isinstance(module, tensor_parallel_module_list)
867
+ for param in module.parameters(recurse=False):
868
+ total_params += param.numel() * (
869
+ fs_init.get_model_parallel_world_size()
870
+ if is_tp_module else 1
871
+ )
872
+ for submodule in module.children():
873
+ _recursive_count_params(submodule)
874
+
875
+ _recursive_count_params(self)
876
+ return total_params
877
+
878
+ def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
879
+ return list(self.layers)
880
+
881
+
882
+ #############################################################################
883
+ # DiT Configs #
884
+ #############################################################################
885
+
886
+
887
+ def DiT_Llama_600M_patch2(**kwargs):
888
+ return DiT_Llama(
889
+ patch_size=2, dim=1536, n_layers=16, n_heads=32, **kwargs
890
+ )
891
+
892
+
893
+ def DiT_Llama_2B_patch2(**kwargs):
894
+ return DiT_Llama(
895
+ patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
896
+ )
897
+
898
+
899
+ def DiT_Llama_3B_patch2(**kwargs):
900
+ return DiT_Llama(
901
+ patch_size=2, dim=3072, n_layers=32, n_heads=32, **kwargs
902
+ )
903
+
904
+
905
+ def DiT_Llama_7B_patch2(**kwargs):
906
+ return DiT_Llama(
907
+ patch_size=2, dim=4096, n_layers=32, n_heads=32, **kwargs
908
+ )
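
Editor's note (illustration only, not part of the uploaded files): the `precompute_freqs_cis` above builds an axial 2-D RoPE table, in which half of the complex channels rotate with the row index and the other half with the column index. Below is a minimal CPU sketch of that same computation, with the `.cuda()` call dropped so it runs anywhere; the 2B config numbers (dim=2304, 32 heads, so head_dim=72) are used only for the shape check.

```python
import torch

def axial_rope_freqs(dim: int, end: int, theta: float = 10000.0,
                     rope_scaling_factor: float = 1.0, ntk_factor: float = 1.0) -> torch.Tensor:
    """CPU re-implementation of the 2-D (axial) RoPE table built above.

    `dim` is the per-head dimension and `end` the maximum grid side length.
    Half of the complex channels rotate with the row index, the other half
    with the column index.
    """
    theta = theta * ntk_factor                                # NTK-aware base rescaling
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))
    t = torch.arange(end, dtype=torch.float) / rope_scaling_factor
    freqs = torch.outer(t, freqs)                             # (end, dim // 4)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)    # complex64
    freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
    freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
    return torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)

table = axial_rope_freqs(dim=72, end=16)   # 2304 / 32 heads = 72 per head
print(table.shape, table.dtype)            # torch.Size([16, 16, 36]) torch.complex64
```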
models/model_5b.py ADDED
@@ -0,0 +1,894 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import functools
13
+ import math
14
+ from typing import Optional, Tuple, List
15
+
16
+ from .components import RMSNorm
17
+ import fairscale.nn.model_parallel.initialize as fs_init
18
+ from fairscale.nn.model_parallel.layers import (
19
+ ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
20
+ )
21
+ from flash_attn import flash_attn_varlen_func
22
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
23
+ import torch
24
+ import torch.distributed as dist
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+
29
+ def modulate(x, shift, scale):
30
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
31
+
32
+
33
+ #############################################################################
34
+ # Embedding Layers for Timesteps and Class Labels #
35
+ #############################################################################
36
+
37
+ class ParallelTimestepEmbedder(nn.Module):
38
+ """
39
+ Embeds scalar timesteps into vector representations.
40
+ """
41
+ def __init__(self, hidden_size, frequency_embedding_size=256):
42
+ super().__init__()
43
+ self.mlp = nn.Sequential(
44
+ ColumnParallelLinear(
45
+ frequency_embedding_size, hidden_size, bias=True,
46
+ gather_output=False,
47
+ init_method=functools.partial(nn.init.normal_, std=0.02),
48
+ ),
49
+ nn.SiLU(),
50
+ RowParallelLinear(
51
+ hidden_size, hidden_size, bias=True, input_is_parallel=True,
52
+ init_method=functools.partial(nn.init.normal_, std=0.02),
53
+ ),
54
+ )
55
+ self.frequency_embedding_size = frequency_embedding_size
56
+
57
+ @staticmethod
58
+ def timestep_embedding(t, dim, max_period=10000):
59
+ """
60
+ Create sinusoidal timestep embeddings.
61
+ :param t: a 1-D Tensor of N indices, one per batch element.
62
+ These may be fractional.
63
+ :param dim: the dimension of the output.
64
+ :param max_period: controls the minimum frequency of the embeddings.
65
+ :return: an (N, D) Tensor of positional embeddings.
66
+ """
67
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
68
+ half = dim // 2
69
+ freqs = torch.exp(
70
+ -math.log(max_period) * torch.arange(
71
+ start=0, end=half, dtype=torch.float32
72
+ ) / half
73
+ ).to(device=t.device)
74
+ args = t[:, None].float() * freqs[None]
75
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
76
+ if dim % 2:
77
+ embedding = torch.cat([
78
+ embedding, torch.zeros_like(embedding[:, :1])
79
+ ], dim=-1)
80
+ return embedding
81
+
82
+ def forward(self, t):
83
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
84
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
85
+ return t_emb
86
+
87
+
88
+ class ParallelLabelEmbedder(nn.Module):
89
+ r"""Embeds class labels into vector representations. Also handles label
90
+ dropout for classifier-free guidance.
91
+ """
92
+ def __init__(self, num_classes, hidden_size, dropout_prob):
93
+ super().__init__()
94
+ use_cfg_embedding = int(dropout_prob > 0)
95
+ self.embedding_table = ParallelEmbedding(
96
+ num_classes + use_cfg_embedding, hidden_size,
97
+ init_method=functools.partial(nn.init.normal_, std=0.02),
98
+ )
99
+ self.num_classes = num_classes
100
+ self.dropout_prob = dropout_prob
101
+
102
+ def token_drop(self, labels, force_drop_ids=None):
103
+ """
104
+ Drops labels to enable classifier-free guidance.
105
+ """
106
+ if force_drop_ids is None:
107
+ drop_ids = torch.rand(
108
+ labels.shape[0], device=labels.device
109
+ ) < self.dropout_prob
110
+ drop_ids = drop_ids.cuda()
111
+ dist.broadcast(
112
+ drop_ids,
113
+ fs_init.get_model_parallel_src_rank(),
114
+ fs_init.get_model_parallel_group(),
115
+ )
116
+ drop_ids = drop_ids.to(labels.device)
117
+ else:
118
+ drop_ids = force_drop_ids == 1
119
+ labels = torch.where(drop_ids, self.num_classes, labels)
120
+ return labels
121
+
122
+ def forward(self, labels, train, force_drop_ids=None):
123
+ use_dropout = self.dropout_prob > 0
124
+ if (train and use_dropout) or (force_drop_ids is not None):
125
+ labels = self.token_drop(labels, force_drop_ids)
126
+ embeddings = self.embedding_table(labels)
127
+ return embeddings
128
+
129
+
130
+ #############################################################################
131
+ # Core DiT Model #
132
+ #############################################################################
133
+
134
+
135
+ class Attention(nn.Module):
136
+ """Multi-head attention module."""
137
+ def __init__(self, dim: int, n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, y_dim: int):
138
+ """
139
+ Initialize the Attention module.
140
+
141
+ Args:
142
+ dim (int): Number of input dimensions.
143
+ n_heads (int): Number of heads.
144
+ n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
145
+
146
+ Attributes:
147
+ n_kv_heads (int): Number of key and value heads.
148
+ n_local_heads (int): Number of local query heads.
149
+ n_local_kv_heads (int): Number of local key and value heads.
150
+ n_rep (int): Number of repetitions for local heads.
151
+ head_dim (int): Dimension size of each attention head.
152
+ wq (ColumnParallelLinear): Linear transformation for queries.
153
+ wk (ColumnParallelLinear): Linear transformation for keys.
154
+ wv (ColumnParallelLinear): Linear transformation for values.
155
+ wo (RowParallelLinear): Linear transformation for output.
156
+ cache_k (torch.Tensor): Cached keys for attention.
157
+ cache_v (torch.Tensor): Cached values for attention.
158
+
159
+ """
160
+ super().__init__()
161
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
162
+ model_parallel_size = fs_init.get_model_parallel_world_size()
163
+ self.n_local_heads = n_heads // model_parallel_size
164
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
165
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
166
+ self.head_dim = dim // n_heads
167
+
168
+ self.wq = ColumnParallelLinear(
169
+ dim, n_heads * self.head_dim, bias=False, gather_output=False,
170
+ init_method=nn.init.xavier_uniform_,
171
+ )
172
+ self.wk = ColumnParallelLinear(
173
+ dim, self.n_kv_heads * self.head_dim, bias=False,
174
+ gather_output=False, init_method=nn.init.xavier_uniform_,
175
+ )
176
+ self.wv = ColumnParallelLinear(
177
+ dim, self.n_kv_heads * self.head_dim, bias=False,
178
+ gather_output=False, init_method=nn.init.xavier_uniform_,
179
+ )
180
+ if y_dim > 0:
181
+ self.wk_y = ColumnParallelLinear(
182
+ y_dim, self.n_kv_heads * self.head_dim, bias=False,
183
+ gather_output=False, init_method=nn.init.xavier_uniform_,
184
+ )
185
+ self.wv_y = ColumnParallelLinear(
186
+ y_dim, self.n_kv_heads * self.head_dim, bias=False,
187
+ gather_output=False, init_method=nn.init.xavier_uniform_,
188
+ )
189
+ self.gate = nn.Parameter(torch.zeros([self.n_local_heads]))
190
+
191
+ self.wo = RowParallelLinear(
192
+ n_heads * self.head_dim, dim, bias=False,
193
+ input_is_parallel=True, init_method=nn.init.xavier_uniform_,
194
+ )
195
+
196
+ if qk_norm:
197
+ self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim)
198
+ self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
199
+ if y_dim > 0:
200
+ self.ky_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
201
+ else:
202
+ self.ky_norm = nn.Identity()
203
+ else:
204
+ self.q_norm = self.k_norm = nn.Identity()
205
+ self.ky_norm = nn.Identity()
206
+
207
+ # for proportional attention computation
208
+ self.base_seqlen = None
209
+ self.proportional_attn = False
210
+
211
+ @staticmethod
212
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
213
+ """
214
+ Reshape frequency tensor for broadcasting it with another tensor.
215
+
216
+ This function reshapes the frequency tensor to have the same shape as
217
+ the target tensor 'x' for the purpose of broadcasting the frequency
218
+ tensor during element-wise operations.
219
+
220
+ Args:
221
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
222
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
223
+
224
+ Returns:
225
+ torch.Tensor: Reshaped frequency tensor.
226
+
227
+ Raises:
228
+ AssertionError: If the frequency tensor doesn't match the expected
229
+ shape.
230
+ AssertionError: If the target tensor 'x' doesn't have the expected
231
+ number of dimensions.
232
+ """
233
+ ndim = x.ndim
234
+ assert 0 <= 1 < ndim
235
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
236
+ shape = [d if i == 1 or i == ndim - 1 else 1
237
+ for i, d in enumerate(x.shape)]
238
+ return freqs_cis.view(*shape)
239
+
240
+ @staticmethod
241
+ def apply_rotary_emb(
242
+ xq: torch.Tensor,
243
+ xk: torch.Tensor,
244
+ freqs_cis: torch.Tensor,
245
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
246
+ """
247
+ Apply rotary embeddings to input tensors using the given frequency
248
+ tensor.
249
+
250
+ This function applies rotary embeddings to the given query 'xq' and
251
+ key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
252
+ input tensors are reshaped as complex numbers, and the frequency tensor
253
+ is reshaped for broadcasting compatibility. The resulting tensors
254
+ contain rotary embeddings and are returned as real tensors.
255
+
256
+ Args:
257
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
258
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
259
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
260
+ exponentials.
261
+
262
+ Returns:
263
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
264
+ and key tensor with rotary embeddings.
265
+ """
266
+ with torch.cuda.amp.autocast(enabled=False):
267
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
268
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
269
+ freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
270
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
271
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
272
+ return xq_out.type_as(xq), xk_out.type_as(xk)
273
+
274
+ # copied from huggingface modeling_llama.py
275
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
276
+
277
+ def _get_unpad_data(attention_mask):
278
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
279
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
280
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
281
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
282
+ return (
283
+ indices,
284
+ cu_seqlens,
285
+ max_seqlen_in_batch,
286
+ )
287
+
288
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
289
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
290
+
291
+ key_layer = index_first_axis(
292
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
293
+ )
294
+ value_layer = index_first_axis(
295
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
296
+ )
297
+ if query_length == kv_seq_len:
298
+ query_layer = index_first_axis(
299
+ query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), indices_k
300
+ )
301
+ cu_seqlens_q = cu_seqlens_k
302
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
303
+ indices_q = indices_k
304
+ elif query_length == 1:
305
+ max_seqlen_in_batch_q = 1
306
+ cu_seqlens_q = torch.arange(
307
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
308
+ ) # There is a memcpy here, that is very bad.
309
+ indices_q = cu_seqlens_q[:-1]
310
+ query_layer = query_layer.squeeze(1)
311
+ else:
312
+ # The -q_len: slice assumes left padding.
313
+ attention_mask = attention_mask[:, -query_length:]
314
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
315
+
316
+ return (
317
+ query_layer,
318
+ key_layer,
319
+ value_layer,
320
+ indices_q,
321
+ (cu_seqlens_q, cu_seqlens_k),
322
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
323
+ )
324
+
325
+ def forward(
326
+ self,
327
+ x: torch.Tensor, x_mask: torch.Tensor,
328
+ freqs_cis: torch.Tensor,
329
+ y: torch.Tensor, y_mask: torch.Tensor,
330
+ ) -> torch.Tensor:
331
+ """
332
+ Forward pass of the attention module.
333
+
334
+ Args:
335
+ x (torch.Tensor): Input tensor.
336
+ freqs_cis (torch.Tensor): Precomputed frequency tensor.
337
+
338
+ Returns:
339
+ torch.Tensor: Output tensor after attention.
340
+
341
+ """
342
+ bsz, seqlen, _ = x.shape
343
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
344
+ dtype = xq.dtype
345
+
346
+ xq = self.q_norm(xq)
347
+ xk = self.k_norm(xk)
348
+
349
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
350
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
351
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
352
+
353
+ xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
354
+ xq, xk = xq.to(dtype), xk.to(dtype)
355
+
356
+ if dtype in [torch.float16, torch.bfloat16]:
357
+ # begin var_len flash attn
358
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
359
+ xq, xk, xv, x_mask, seqlen
360
+ )
361
+
362
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
363
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
364
+
365
+ if self.proportional_attn:
366
+ softmax_scale = math.sqrt(math.log(seqlen, self.base_seqlen) / self.head_dim)
367
+ else:
368
+ softmax_scale = math.sqrt(1 / self.head_dim)
369
+ attn_output_unpad = flash_attn_varlen_func(
370
+ query_states,
371
+ key_states,
372
+ value_states,
373
+ cu_seqlens_q=cu_seqlens_q,
374
+ cu_seqlens_k=cu_seqlens_k,
375
+ max_seqlen_q=max_seqlen_in_batch_q,
376
+ max_seqlen_k=max_seqlen_in_batch_k,
377
+ dropout_p=0.,
378
+ causal=False,
379
+ softmax_scale=softmax_scale
380
+ )
381
+ output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
382
+ # end var_len_flash_attn
383
+
384
+ else:
385
+ output = F.scaled_dot_product_attention(
386
+ xq.permute(0, 2, 1, 3),
387
+ xk.permute(0, 2, 1, 3),
388
+ xv.permute(0, 2, 1, 3),
389
+ attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
390
+ ).permute(0, 2, 1, 3).to(dtype)
391
+
392
+ if hasattr(self, "wk_y"):
393
+ yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
394
+ yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
395
+ n_rep = self.n_local_heads // self.n_local_kv_heads
396
+ if n_rep >= 1:
397
+ yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
398
+ yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
399
+ output_y = F.scaled_dot_product_attention(
400
+ xq.permute(0, 2, 1, 3),
401
+ yk.permute(0, 2, 1, 3),
402
+ yv.permute(0, 2, 1, 3),
403
+ y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1)
404
+ ).permute(0, 2, 1, 3)
405
+ output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
406
+ output = output + output_y
407
+
408
+ output = output.flatten(-2)
409
+
410
+ return self.wo(output)
411
+
412
+
413
+ class FeedForward(nn.Module):
414
+ def __init__(
415
+ self,
416
+ dim: int,
417
+ hidden_dim: int,
418
+ multiple_of: int,
419
+ ffn_dim_multiplier: Optional[float],
420
+ ):
421
+ """
422
+ Initialize the FeedForward module.
423
+
424
+ Args:
425
+ dim (int): Input dimension.
426
+ hidden_dim (int): Hidden dimension of the feedforward layer.
427
+ multiple_of (int): Value to ensure hidden dimension is a multiple
428
+ of this value.
429
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden
430
+ dimension. Defaults to None.
431
+
432
+ Attributes:
433
+ w1 (ColumnParallelLinear): Linear transformation for the first
434
+ layer.
435
+ w2 (RowParallelLinear): Linear transformation for the second layer.
436
+ w3 (ColumnParallelLinear): Linear transformation for the third
437
+ layer.
438
+
439
+ """
440
+ super().__init__()
441
+ hidden_dim = int(2 * hidden_dim / 3)
442
+ # custom dim factor multiplier
443
+ if ffn_dim_multiplier is not None:
444
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
445
+ hidden_dim = multiple_of * (
446
+ (hidden_dim + multiple_of - 1) // multiple_of
447
+ )
448
+
449
+ self.w1 = ColumnParallelLinear(
450
+ dim, hidden_dim, bias=False, gather_output=False,
451
+ init_method=nn.init.xavier_uniform_,
452
+ )
453
+ self.w2 = RowParallelLinear(
454
+ hidden_dim, dim, bias=False, input_is_parallel=True,
455
+ init_method=nn.init.xavier_uniform_,
456
+ )
457
+ self.w3 = ColumnParallelLinear(
458
+ dim, hidden_dim, bias=False, gather_output=False,
459
+ init_method=nn.init.xavier_uniform_,
460
+ )
461
+
462
+ # @torch.compile
463
+ def _forward_silu_gating(self, x1, x3):
464
+ return F.silu(x1) * x3
465
+
466
+ def forward(self, x):
467
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
468
+
469
+
470
+ class TransformerBlock(nn.Module):
471
+ def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int,
472
+ multiple_of: int, ffn_dim_multiplier: float, norm_eps: float,
473
+ qk_norm: bool, y_dim: int) -> None:
474
+ """
475
+ Initialize a TransformerBlock.
476
+
477
+ Args:
478
+ layer_id (int): Identifier for the layer.
479
+ dim (int): Embedding dimension of the input features.
480
+ n_heads (int): Number of attention heads.
481
+ n_kv_heads (Optional[int]): Number of attention heads in key and
482
+ value features (if using GQA), or set to None for the same as
483
+ query.
484
+ multiple_of (int):
485
+ ffn_dim_multiplier (float):
486
+ norm_eps (float):
487
+
488
+ Attributes:
489
+ n_heads (int): Number of attention heads.
490
+ dim (int): Dimension size of the model.
491
+ head_dim (int): Dimension size of each attention head.
492
+ attention (Attention): Attention module.
493
+ feed_forward (FeedForward): FeedForward module.
494
+ layer_id (int): Identifier for the layer.
495
+ attention_norm (RMSNorm): Layer normalization for attention output.
496
+ ffn_norm (RMSNorm): Layer normalization for feedforward output.
497
+
498
+ """
499
+ super().__init__()
500
+ self.dim = dim
501
+ self.head_dim = dim // n_heads
502
+ self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim)
503
+ self.feed_forward = FeedForward(
504
+ dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of,
505
+ ffn_dim_multiplier=ffn_dim_multiplier,
506
+ )
507
+ self.layer_id = layer_id
508
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
509
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
510
+
511
+ self.adaLN_modulation = nn.Sequential(
512
+ nn.SiLU(),
513
+ ColumnParallelLinear(
514
+ min(dim, 1024), 6 * dim, bias=True, gather_output=True,
515
+ init_method=nn.init.zeros_,
516
+ ),
517
+ )
518
+
519
+ self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)
520
+
521
+ def forward(
522
+ self,
523
+ x: torch.Tensor,
524
+ x_mask: torch.Tensor,
525
+ y: torch.Tensor,
526
+ y_mask: torch.Tensor,
527
+ freqs_cis: torch.Tensor,
528
+ adaln_input: Optional[torch.Tensor] = None,
529
+ ):
530
+ """
531
+ Perform a forward pass through the TransformerBlock.
532
+
533
+ Args:
534
+ x (torch.Tensor): Input tensor.
535
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
536
+ mask (torch.Tensor, optional): Masking tensor for attention.
537
+ Defaults to None.
538
+
539
+ Returns:
540
+ torch.Tensor: Output tensor after applying attention and
541
+ feedforward layers.
542
+
543
+ """
544
+ if adaln_input is not None:
545
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
546
+ self.adaLN_modulation(adaln_input).chunk(6, dim=1)
547
+
548
+ x = x + gate_msa.unsqueeze(1) * self.attention(
549
+ modulate(self.attention_norm(x), shift_msa, scale_msa),
550
+ x_mask,
551
+ freqs_cis,
552
+ self.attention_y_norm(y), y_mask,
553
+ )
554
+ x = x + gate_mlp.unsqueeze(1) * self.feed_forward(
555
+ modulate(self.ffn_norm(x), shift_mlp, scale_mlp),
556
+ )
557
+
558
+ else:
559
+ x = x + self.attention(
560
+ self.attention_norm(x), x_mask, freqs_cis, self.attention_y_norm(y), y_mask,
561
+ )
562
+ x = x + self.feed_forward(self.ffn_norm(x))
563
+
564
+ return x
565
+
566
+ class ParallelFinalLayer(nn.Module):
567
+ """
568
+ The final layer of DiT.
569
+ """
570
+ def __init__(self, hidden_size, patch_size, out_channels):
571
+ super().__init__()
572
+ self.norm_final = nn.LayerNorm(
573
+ hidden_size, elementwise_affine=False, eps=1e-6,
574
+ )
575
+ self.linear = ColumnParallelLinear(
576
+ hidden_size, patch_size * patch_size * out_channels, bias=True,
577
+ init_method=nn.init.zeros_, gather_output=True,
578
+ )
579
+ self.adaLN_modulation = nn.Sequential(
580
+ nn.SiLU(),
581
+ ColumnParallelLinear(
582
+ min(hidden_size, 1024), 2 * hidden_size, bias=True,
583
+ init_method=nn.init.zeros_, gather_output=True,
584
+ ),
585
+ )
586
+
587
+ def forward(self, x, c):
588
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
589
+ x = modulate(self.norm_final(x), shift, scale)
590
+ x = self.linear(x)
591
+ return x
592
+
593
+
594
+ class DiT_Llama(nn.Module):
595
+ """
596
+ Diffusion model with a Transformer backbone.
597
+ """
598
+ def __init__(
599
+ self,
600
+ patch_size: int = 2,
601
+ in_channels: int = 4,
602
+ dim: int = 4096,
603
+ n_layers: int = 32,
604
+ n_heads: int = 32,
605
+ n_kv_heads: Optional[int] = None,
606
+ multiple_of: int = 256,
607
+ ffn_dim_multiplier: Optional[float] = None,
608
+ norm_eps: float = 1e-5,
609
+ learn_sigma: bool = True,
610
+ qk_norm: bool = False,
611
+ cap_feat_dim: int = 5120,
612
+ rope_scaling_factor: float = 1.,
613
+ ntk_factor: float=1.
614
+ ) -> None:
615
+ super().__init__()
616
+ self.learn_sigma = learn_sigma
617
+ self.in_channels = in_channels
618
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
619
+ self.patch_size = patch_size
620
+
621
+ self.x_embedder = ColumnParallelLinear(
622
+ in_features=patch_size * patch_size * in_channels,
623
+ out_features=dim,
624
+ bias=True,
625
+ gather_output=True,
626
+ init_method=nn.init.xavier_uniform_,
627
+ )
628
+ nn.init.constant_(self.x_embedder.bias, 0.)
629
+
630
+ self.t_embedder = ParallelTimestepEmbedder(min(dim, 1024))
631
+ self.cap_embedder = nn.Sequential(
632
+ nn.LayerNorm(cap_feat_dim),
633
+ ColumnParallelLinear(
634
+ cap_feat_dim, min(dim, 1024), bias=True, gather_output=True,
635
+ init_method=nn.init.zeros_
636
+ ),
637
+ )
638
+
639
+ self.layers = nn.ModuleList([
640
+ TransformerBlock(layer_id, dim, n_heads, n_kv_heads, multiple_of,
641
+ ffn_dim_multiplier, norm_eps, qk_norm, cap_feat_dim)
642
+ for layer_id in range(n_layers)
643
+ ])
644
+ self.final_layer = ParallelFinalLayer(dim, patch_size, self.out_channels)
645
+
646
+ self.freqs_cis = DiT_Llama.precompute_freqs_cis(
647
+ dim // n_heads, 40000, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
648
+ )
649
+ self.dim = dim
650
+ self.n_heads = n_heads
651
+ self.rope_scaling_factor = rope_scaling_factor
652
+ self.ntk_factor = ntk_factor
653
+ self.eol_token = nn.Parameter(torch.empty(dim))
654
+ self.pad_token = nn.Parameter(torch.empty(dim))
655
+ nn.init.normal_(self.eol_token, std=0.02)
656
+ nn.init.normal_(self.pad_token, std=0.02)
657
+
658
+ def unpatchify(self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False) -> List[torch.Tensor]:
659
+ """
660
+ x: (N, T, patch_size**2 * C)
661
+ imgs: (N, H, W, C)
662
+ """
663
+ pH = pW = self.patch_size
664
+ if return_tensor:
665
+ H, W = img_size[0]
666
+ B = x.size(0)
667
+ L = (H // pH) * (W // pW + 1) # one additional for eol
668
+ x = x[:, :L].view(B, H // pH, W // pW + 1, pH, pW, self.out_channels)
669
+ x = x[:, :, :-1]
670
+ x = x.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
671
+ return x
672
+ else:
673
+ imgs = []
674
+ for i in range(x.size(0)):
675
+ H, W = img_size[i]
676
+ L = (H // pH) * (W // pW + 1)
677
+ imgs.append(x[i][:L].view(
678
+ H // pH, W // pW + 1, pH, pW, self.out_channels
679
+ )[:, :-1, :, :, :].permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2))
680
+ return imgs
681
+
682
+ def patchify_and_embed(
683
+ self,
684
+ x: List[torch.Tensor] | torch.Tensor
685
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]]]:
686
+ if isinstance(x, torch.Tensor):
687
+ pH = pW = self.patch_size
688
+ B, C, H, W = x.size()
689
+ x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3)
690
+ x = self.x_embedder(x)
691
+ x = torch.cat([
692
+ x,
693
+ self.eol_token.view(1, 1, 1, -1).expand(B, H // pH, 1, -1),
694
+ ], dim=2)
695
+ x = x.flatten(1, 2)
696
+
697
+ mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
698
+ return x, mask, [(H, W)] * B
699
+ else:
700
+ pH = pW = self.patch_size
701
+ x_embed = []
702
+ img_size = []
703
+ l_effective_seq_len = []
704
+
705
+ for img in x:
706
+ C, H, W = img.size()
707
+ img_size.append((H, W))
708
+ img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 0, 2, 4).flatten(2)
709
+ img = self.x_embedder(img)
710
+ img = torch.cat([
711
+ img,
712
+ self.eol_token.view(1, 1, -1).expand(H // pH, 1, -1),
713
+ ], dim=1)
714
+ img = img.flatten(0, 1)
715
+ l_effective_seq_len.append(len(img))
716
+ x_embed.append(img)
717
+
718
+ max_seq_len = max(l_effective_seq_len)
719
+ mask = torch.zeros(len(x), max_seq_len, dtype=torch.int32, device=x[0].device)
720
+ padded_x_embed = []
721
+ for i, (item_embed, item_seq_len) in enumerate(zip(x_embed, l_effective_seq_len)):
722
+ item_embed = torch.cat([
723
+ item_embed,
724
+ self.pad_token.view(1, -1).expand(max_seq_len - item_seq_len, -1),
725
+ ], dim=0)
726
+ padded_x_embed.append(item_embed)
727
+ mask[i][:item_seq_len] = 1
728
+
729
+ x_embed = torch.stack(padded_x_embed, dim=0)
730
+ return x_embed, mask, img_size
731
+
732
+ def forward(self, x, t, cap_feats, cap_mask):
733
+ """
734
+ Forward pass of DiT.
735
+ t: (N,) tensor of diffusion timesteps
736
+ y: (N,) tensor of class labels
737
+ """
738
+ x_is_tensor = isinstance(x, torch.Tensor)
739
+ x, mask, img_size = self.patchify_and_embed(x)
740
+ self.freqs_cis = self.freqs_cis.to(x.device)
741
+
742
+ t = self.t_embedder(t) # (N, D)
743
+ cap_mask_float = cap_mask.float().unsqueeze(-1)
744
+ cap_feats_pool = (cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(dim=1)
745
+ cap_feats_pool = cap_feats_pool.to(cap_feats)
746
+ cap_emb = self.cap_embedder(cap_feats_pool)
747
+ adaln_input = t + cap_emb
748
+
749
+ cap_mask = cap_mask.bool()
750
+ for layer in self.layers:
751
+ x = layer(
752
+ x, mask, cap_feats, cap_mask, self.freqs_cis[:x.size(1)],
753
+ adaln_input=adaln_input
754
+ )
755
+
756
+ x = self.final_layer(x, adaln_input)
757
+ x = self.unpatchify(x, img_size, return_tensor=x_is_tensor)
758
+ if self.learn_sigma:
759
+ if x_is_tensor:
760
+ x, _ = x.chunk(2, dim=1)
761
+ else:
762
+ x = [_.chunk(2, dim=0)[0] for _ in x]
763
+ return x
764
+
765
+ def forward_with_cfg(
766
+ self,
767
+ x,
768
+ t,
769
+ cap_feats,
770
+ cap_mask,
771
+ cfg_scale,
772
+ rope_scaling_factor=None,
773
+ ntk_factor=None,
774
+ base_seqlen: Optional[int] = None,
775
+ proportional_attn: bool = False
776
+ ):
777
+ """
778
+ Forward pass of DiT, but also batches the unconditional forward pass
779
+ for classifier-free guidance.
780
+ """
781
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
782
+
783
+ if rope_scaling_factor is not None or ntk_factor is not None:
784
+ rope_scaling_factor = rope_scaling_factor if rope_scaling_factor is not None else self.rope_scaling_factor
785
+ ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
786
+ if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
787
+ print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
788
+ self.freqs_cis = DiT_Llama.precompute_freqs_cis(
789
+ self.dim // self.n_heads, 40000,
790
+ rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
791
+ )
792
+ self.rope_scaling_factor = rope_scaling_factor
793
+ self.ntk_factor = ntk_factor
794
+
795
+ if proportional_attn:
796
+ assert base_seqlen is not None
797
+ for layer in self.layers:
798
+ layer.attention.base_seqlen = base_seqlen
799
+ layer.attention.proportional_attn = proportional_attn
800
+ else:
801
+ for layer in self.layers:
802
+ layer.attention.base_seqlen = None
803
+ layer.attention.proportional_attn = proportional_attn
804
+
805
+ half = x[: len(x) // 2]
806
+ combined = torch.cat([half, half], dim=0)
807
+ model_out = self(combined, t, cap_feats, cap_mask)
808
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
809
+ # three channels by default. The standard approach to cfg applies it to all channels.
810
+ # This can be done by uncommenting the following line and commenting-out the line following that.
811
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
812
+ eps, rest = model_out[:, :3], model_out[:, 3:]
813
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
814
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
815
+ eps = torch.cat([half_eps, half_eps], dim=0)
816
+ return torch.cat([eps, rest], dim=1)
817
+
818
+ @staticmethod
819
+ def precompute_freqs_cis(
820
+ dim: int,
821
+ end: int,
822
+ theta: float = 10000.0,
823
+ rope_scaling_factor: float = 1.0,
824
+ ntk_factor: float = 1.0
825
+ ):
826
+ """
827
+ Precompute the frequency tensor for complex exponentials (cis) with
828
+ given dimensions.
829
+
830
+ This function calculates a frequency tensor with complex exponentials
831
+ using the given dimension 'dim' and the end index 'end'. The 'theta'
832
+ parameter scales the frequencies. The returned tensor contains complex
833
+ values in complex64 data type.
834
+
835
+ Args:
836
+ dim (int): Dimension of the frequency tensor.
837
+ end (int): End index for precomputing frequencies.
838
+ theta (float, optional): Scaling factor for frequency computation.
839
+ Defaults to 10000.0.
840
+
841
+ Returns:
842
+ torch.Tensor: Precomputed frequency tensor with complex
843
+ exponentials.
844
+ """
845
+
846
+ theta = theta * ntk_factor
847
+
848
+ print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
849
+ freqs = 1.0 / (theta ** (
850
+ torch.arange(0, dim, 2)[: (dim // 2)].float().cuda() / dim
851
+ ))
852
+ t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
853
+ t = t / rope_scaling_factor
854
+ freqs = torch.outer(t, freqs).float() # type: ignore
855
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
856
+ return freqs_cis
857
+
858
+ def parameter_count(self) -> int:
859
+ tensor_parallel_module_list = (
860
+ ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
861
+ )
862
+ total_params = 0
863
+
864
+ def _recursive_count_params(module):
865
+ nonlocal total_params
866
+ is_tp_module = isinstance(module, tensor_parallel_module_list)
867
+ for param in module.parameters(recurse=False):
868
+ total_params += param.numel() * (
869
+ fs_init.get_model_parallel_world_size()
870
+ if is_tp_module else 1
871
+ )
872
+ for submodule in module.children():
873
+ _recursive_count_params(submodule)
874
+
875
+ _recursive_count_params(self)
876
+ return total_params
877
+
878
+ def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
879
+ return list(self.layers)
880
+
881
+
882
+ #############################################################################
883
+ # DiT Configs #
884
+ #############################################################################
885
+
886
+ def DiT_Llama_2B_patch2(**kwargs):
887
+ return DiT_Llama(
888
+ patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
889
+ )
890
+
891
+ def DiT_Llama_5B_patch2(**kwargs):
892
+ return DiT_Llama(
893
+ patch_size=2, dim=3072, n_layers=32, n_heads=32, **kwargs
894
+ )
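
Editor's note (illustration only, not part of the upload): both model files implement classifier-free guidance by running the conditional and unconditional halves in a single batch and recombining the outputs. The sketch below restates just the recombination step from `forward_with_cfg`, with a random tensor standing in for the model output; guidance is applied to the first three channels only, matching the default above.

```python
import torch

def cfg_combine(model_out: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    """Recombine a cond/uncond batch the way `forward_with_cfg` does.

    `model_out` stacks conditional outputs first and unconditional outputs
    second along the batch axis.
    """
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)

out = cfg_combine(torch.randn(4, 8, 32, 32), cfg_scale=4.0)  # 2 cond + 2 uncond samples
print(out.shape)                                             # torch.Size([4, 8, 32, 32])
```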
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ transformers
2
+ diffusers
3
+ huggingface_hub
4
+ gradio
5
+ torch
6
+ # torch==2.2.2+cu121
7
+ fairscale
8
+ numpy
9
+ pillow
10
+ torchdiffeq
11
+ click
12
+ git+https://github.com/Alpha-VLLM/Lumina-T2X
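
Editor's note: `flash_attn`, which both model files import for the half-precision attention path, is not listed above; it is presumably expected from the environment or pulled in via the git dependency. Separately, the models expose `rope_scaling_factor`, `ntk_factor`, `base_seqlen`, and `proportional_attn` for sampling beyond the training sequence length; below is a minimal sketch of the softmax scale the attention modules switch between (names mirror the attributes above).

```python
import math
from typing import Optional

def softmax_scale(seqlen: int, head_dim: int,
                  base_seqlen: Optional[int] = None,
                  proportional_attn: bool = False) -> float:
    """Softmax scale used by the flash-attention path in Attention.forward.

    With proportional attention enabled, the usual 1/sqrt(d) scale grows with
    log_{base_seqlen}(seqlen), which is intended to keep attention behaviour
    stable when the sequence is longer than the training length.
    """
    if proportional_attn:
        return math.sqrt(math.log(seqlen, base_seqlen) / head_dim)
    return math.sqrt(1 / head_dim)

print(softmax_scale(1024, 72))                                            # ~0.118
print(softmax_scale(4096, 72, base_seqlen=1024, proportional_attn=True))  # ~0.129
```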