PommesPeter committed on
Commit be34a3d
1 Parent(s): 35f1cc3

Update app.py

Files changed (1):
  1. app.py +196 -193
app.py CHANGED
@@ -1,16 +1,25 @@
 import os
 import subprocess
-subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
 
 from huggingface_hub import snapshot_download
+
 os.makedirs("/home/user/app/checkpoints", exist_ok=True)
-snapshot_download(repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints")
+snapshot_download(
+    repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints"
+)
 
 import argparse
 import builtins
 import json
 import random
 import socket
+
 import spaces
 import traceback
 
@@ -39,14 +48,14 @@ description = """
 Demo current model: `Lumina-Next-T2I`
 """
 
-hf_token = os.environ['HF_TOKEN']
+hf_token = os.environ["HF_TOKEN"]
+
 
 class ModelFailure:
     pass
 
 
 # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
-@spaces.GPU
 def encode_prompt(
     prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True
 ):
@@ -83,8 +92,7 @@ def encode_prompt(
     return prompt_embeds, prompt_masks
 
 
-@spaces.GPU(duration=200)
-def load_model(args, master_port, rank):
+def load_models(args, master_port, rank):
     # import here to avoid huggingface Tokenizer parallelism warnings
     from diffusers.models import AutoencoderKL
     from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -100,31 +108,21 @@ def load_model(args, master_port, rank):
     # Override the built-in print with the new version
     builtins.print = print
 
-    os.environ["MASTER_PORT"] = str(master_port)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["RANK"] = str(rank)
-    os.environ["WORLD_SIZE"] = str(args.num_gpus)
-
-    dist.init_process_group("nccl")
-    # set up fairscale environment because some methods of the Lumina model need it,
-    # though for single-GPU inference fairscale actually has no effect
-    fs_init.initialize_model_parallel(args.num_gpus)
-    torch.cuda.set_device(rank)
-
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
-    if dist.get_rank() == 0:
-        print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))
-
-    if dist.get_rank() == 0:
-        print(f"Creating lm: Gemma-2B")
+    print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))
 
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
 
+    print(f"Creating lm: Gemma-2B")
     text_encoder = (
         AutoModelForCausalLM.from_pretrained(
-            "google/gemma-2b", torch_dtype=dtype, device_map="cuda", token=hf_token,
+            "google/gemma-2b",
+            torch_dtype=dtype,
+            device_map="cpu",
+            # device_map="cuda",
+            token=hf_token,
         )
         .get_decoder()
         .eval()
@@ -134,24 +132,27 @@ def load_model(args, master_port, rank):
         raise NotImplementedError("Inference with >1 GPUs not yet supported")
 
     tokenizer = AutoTokenizer.from_pretrained(
-        "google/gemma-2b", add_bos_token=True, add_eos_token=True, token=hf_token,
+        "google/gemma-2b",
+        add_bos_token=True,
+        add_eos_token=True,
+        token=hf_token,
    )
     tokenizer.padding_side = "right"
 
-    if dist.get_rank() == 0:
-        print(f"Creating vae: sdxl-vae")
-    vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae",
+    print(f"Creating vae: sdxl-vae")
+    vae = AutoencoderKL.from_pretrained(
+        "stabilityai/sdxl-vae",
         torch_dtype=torch.float32,
-    ).cuda()
+    )
 
-    if dist.get_rank() == 0:
-        print(f"Creating DiT: Next-DiT")
+    print(f"Creating DiT: Next-DiT")
     # latent_size = train_args.image_size // 8
     model = models.__dict__["NextDiT_2B_patch2"](
         qk_norm=train_args.qk_norm,
         cap_feat_dim=cap_feat_dim,
     )
-    model.eval().to("cuda", dtype=dtype)
+    # model.eval().to("cuda", dtype=dtype)
+    model.eval()
 
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
@@ -169,137 +170,141 @@ def load_model(args, master_port, rank):
     return text_encoder, tokenizer, vae, model
 
 
-@spaces.GPU
 @torch.no_grad()
-def model_main(args, master_port, rank, request_queue, response_queue):
+def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
-    text_encoder, tokenizer, vae, model = load_model(args, master_port, rank)
+
+    print(args)
+
+    os.environ["MASTER_PORT"] = str(60001)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["RANK"] = str(0)
+    os.environ["WORLD_SIZE"] = str(args.num_gpus)
+
+    # dist.init_process_group("nccl")
+    # set up fairscale environment because some methods of the Lumina model need it,
+    # though for single-GPU inference fairscale actually has no effect
+    # fs_init.initialize_model_parallel(args.num_gpus)
+    torch.cuda.set_device(0)
+
+    # loading model to gpu
+    text_encoder = text_encoder.cuda()
+    vae = vae.cuda()
+    model = model.to("cuda", dtype=dtype)
 
     with torch.autocast("cuda", dtype):
         # barrier.wait()
-        while True:
-            (
-                cap,
-                resolution,
-                num_sampling_steps,
-                cfg_scale,
-                solver,
-                t_shift,
-                seed,
-                ntk_scaling,
-                proportional_attn,
-            ) = request_queue.get()
-
-            print(
-                "> params:",
-                cap,
-                resolution,
-                num_sampling_steps,
-                cfg_scale,
-                solver,
-                t_shift,
-                seed,
-                ntk_scaling,
-                proportional_attn,
-            )
-            try:
-                # begin sampler
-                transport = create_transport(
-                    args.path_type,
-                    args.prediction,
-                    args.loss_weight,
-                    args.train_eps,
-                    args.sample_eps,
-                )
-                sampler = Sampler(transport)
-                if args.sampler_mode == "ODE":
-                    if args.likelihood:
-                        # assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" # todo
-                        sample_fn = sampler.sample_ode_likelihood(
-                            sampling_method=solver,
-                            num_steps=num_sampling_steps,
-                            atol=args.atol,
-                            rtol=args.rtol,
-                        )
-                    else:
-                        sample_fn = sampler.sample_ode(
-                            sampling_method=solver,
-                            num_steps=num_sampling_steps,
-                            atol=args.atol,
-                            rtol=args.rtol,
-                            reverse=args.reverse,
-                            time_shifting_factor=t_shift,
-                        )
-                elif args.sampler_mode == "SDE":
-                    sample_fn = sampler.sample_sde(
-                        sampling_method=solver,
-                        diffusion_form=args.diffusion_form,
-                        diffusion_norm=args.diffusion_norm,
-                        last_step=args.last_step,
-                        last_step_size=args.last_step_size,
-                        num_steps=num_sampling_steps,
-                    )
-                # end sampler
-
-                resolution = resolution.split(" ")[-1]
-                w, h = resolution.split("x")
-                w, h = int(w), int(h)
-                latent_w, latent_h = w // 8, h // 8
-                if int(seed) != 0:
-                    torch.random.manual_seed(int(seed))
-                z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
-                z = z.repeat(2, 1, 1, 1)
-
-                with torch.no_grad():
-                    cap_feats, cap_mask = encode_prompt(
-                        [cap] + [""], text_encoder, tokenizer, 0.0
-                    )
-                    cap_mask = cap_mask.to(cap_feats.device)
-
-                    train_res = 1024
-                    res_cat = (w * h) ** 0.5
-                    print(f"res_cat: {res_cat}")
-                    max_seq_len = (res_cat // 16) ** 2 + (res_cat // 16) * 2
-                    print(f"max_seq_len: {max_seq_len}")
-
-                    rope_scaling_factor = 1.0
-                    ntk_factor = max_seq_len / (train_res // 16) ** 2
-                    print(f"ntk_factor: {ntk_factor}")
-
-                    model_kwargs = dict(
-                        cap_feats=cap_feats,
-                        cap_mask=cap_mask,
-                        cfg_scale=cfg_scale,
-                        rope_scaling_factor=rope_scaling_factor,
-                        ntk_factor=ntk_factor,
-                    )
-
-                if dist.get_rank() == 0:
-                    print(f"caption: {cap}")
-                    print(f"num_sampling_steps: {num_sampling_steps}")
-                    print(f"cfg_scale: {cfg_scale}")
-
-                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                    print("> [debug] start sample")
-                    samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
-                    samples = samples[:1]
-
-                factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
-                print(f"vae factor: {factor}")
-                samples = vae.decode(samples / factor).sample
-                samples = (samples + 1.0) / 2.0
-                samples.clamp_(0.0, 1.0)
-                img = to_pil_image(samples[0].float())
-
-                if response_queue is not None:
-                    response_queue.put(img)
-
-            except Exception:
-                print(traceback.format_exc())
-                response_queue.put(ModelFailure())
+        (
+            cap,
+            resolution,
+            num_sampling_steps,
+            cfg_scale,
+            solver,
+            t_shift,
+            seed,
+            ntk_scaling,
+            proportional_attn,
+        ) = infer_args
+
+        print(
+            "> params:",
+            cap,
+            resolution,
+            num_sampling_steps,
+            cfg_scale,
+            solver,
+            t_shift,
+            seed,
+            ntk_scaling,
+            proportional_attn,
+        )
+        try:
+            # begin sampler
+            transport = create_transport(
+                args.path_type,
+                args.prediction,
+                args.loss_weight,
+                args.train_eps,
+                args.sample_eps,
+            )
+            sampler = Sampler(transport)
+            if args.likelihood:
+                # assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" # todo
+                sample_fn = sampler.sample_ode_likelihood(
+                    sampling_method=solver,
+                    num_steps=num_sampling_steps,
+                    atol=args.atol,
+                    rtol=args.rtol,
+                )
+            else:
+                sample_fn = sampler.sample_ode(
+                    sampling_method=solver,
+                    num_steps=num_sampling_steps,
+                    atol=args.atol,
+                    rtol=args.rtol,
+                    reverse=args.reverse,
+                    time_shifting_factor=t_shift,
+                )
+            # end sampler
+
+            resolution = resolution.split(" ")[-1]
+            w, h = resolution.split("x")
+            w, h = int(w), int(h)
+            latent_w, latent_h = w // 8, h // 8
+            if int(seed) != 0:
+                torch.random.manual_seed(int(seed))
+            z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
+            z = z.repeat(2, 1, 1, 1)
+
+            with torch.no_grad():
+                cap_feats, cap_mask = encode_prompt(
+                    [cap] + [""], text_encoder, tokenizer, 0.0
+                )
+                cap_mask = cap_mask.to(cap_feats.device)
+
+                train_res = 1024
+                res_cat = (w * h) ** 0.5
+                print(f"res_cat: {res_cat}")
+                max_seq_len = (res_cat // 16) ** 2 + (res_cat // 16) * 2
+                print(f"max_seq_len: {max_seq_len}")
+
+                rope_scaling_factor = 1.0
+                ntk_factor = max_seq_len / (train_res // 16) ** 2
+                print(f"ntk_factor: {ntk_factor}")
+
+                model_kwargs = dict(
+                    cap_feats=cap_feats,
+                    cap_mask=cap_mask,
+                    cfg_scale=cfg_scale,
+                    rope_scaling_factor=rope_scaling_factor,
+                    ntk_factor=ntk_factor,
+                )
+
+            print(f"caption: {cap}")
+            print(f"num_sampling_steps: {num_sampling_steps}")
+            print(f"cfg_scale: {cfg_scale}")
+
+            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                print("> [debug] start sample")
+                samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
+                samples = samples[:1]
+
+            factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
+            print(f"vae factor: {factor}")
+
+            samples = vae.decode(samples / factor).sample
+            samples = (samples + 1.0) / 2.0
+            samples.clamp_(0.0, 1.0)
+
+            img = to_pil_image(samples[0].float())
+
+            return img
+        except Exception:
+            print(traceback.format_exc())
+            return ModelFailure()
 
 
 def none_or_str(value):
@@ -412,7 +417,6 @@ def find_free_port() -> int:
     return port
 
 
-@spaces.GPU
 def main():
     parser = argparse.ArgumentParser()
     mode = "ODE"
@@ -423,13 +427,7 @@ def main():
     parser.add_argument("--precision", default="bf16", choices=["bf16", "fp32"])
 
     parse_transport_args(parser)
-    if mode == "ODE":
-        parse_ode_args(parser)
-        # Further processing for ODE
-    elif mode == "SDE":
-        parse_sde_args(parser)
-        # Further processing for SDE
-
+    parse_ode_args(parser)
     args = parser.parse_known_args()[0]
 
     if args.num_gpus != 1:
@@ -437,24 +435,7 @@ def main():
 
     args.sampler_mode = mode
 
-    master_port = find_free_port()
-
-    request_queues = []
-    response_queue = Queue()
-    # mp_barrier = mp.Barrier(args.num_gpus + 1)
-    # barrier = Barrier(args.num_gpus + 1)
-    for i in range(args.num_gpus):
-        request_queues.append(Queue())
-        generation_kwargs = dict(
-            args=args,
-            master_port=master_port,
-            rank=i,
-            request_queue=request_queues[i],
-            response_queue=response_queue if i == 0 else None,
-        )
-        model_main(**generation_kwargs)
-        # thread = Thread(target=model_main, kwargs=generation_kwargs)
-        # thread.start()
+    text_encoder, tokenizer, vae, model = load_models(args, 60001, 0)
 
     with gr.Blocks() as demo:
         with gr.Row():
@@ -482,6 +463,7 @@ def main():
                         minimum=1,
                         maximum=70,
                         value=30,
+                        step=1,
                         interactive=True,
                         label="Sampling steps",
                     )
@@ -537,12 +519,10 @@ def main():
             # ntk_scaling, proportional_attn
             # ])
             with gr.Column():
-                # default_img = Image.open("./image.png")
                 output_img = gr.Image(
                     label="Generated image",
                     interactive=False,
                     format="png",
-                    # value=default_img,
                 )
 
             with gr.Row():
@@ -557,35 +537,60 @@ def main():
                 ["味噌ラーメン, 最高品質の浮世絵、江戸時代。"],
                 ["東京タワー、最高品質の浮世絵、江戸時代。"],
                 ["Astronaut on Mars During sunset"],
-                ["Tour de Tokyo, estampes ukiyo-e de la plus haute qualité, période Edo"],
+                [
+                    "Tour de Tokyo, estampes ukiyo-e de la plus haute qualité, période Edo"
+                ],
                 ["🐔 playing 🏀"],
                 ["☃️ with 🌹 in the ❄️"],
                 ["🐶 wearing 😎 flying on 🌈 "],
                 ["A small 🍎 and 🍊 with 😁 emoji in the Sahara desert"],
                 ["Токийская башня, лучшие укиё-э, период Эдо"],
                 ["Tokio-Turm, hochwertigste Ukiyo-e, Edo-Zeit"],
-                ["A scared cute rabbit in Happy Tree Friends style and punk vibe."], # noqa
+                [
+                    "A scared cute rabbit in Happy Tree Friends style and punk vibe."
+                ], # noqa
                 ["A humanoid eagle soldier of the First World War."], # noqa
-                ["A cute Christmas mockup on an old wooden industrial desk table with Christmas decorations and bokeh lights in the background."],
-                ["A front view of a romantic flower shop in France filled with various blooming flowers including lavenders and roses."],
-                ["An old man, portrayed as a retro superhero, stands in the streets of New York City at night"],
-                ["many trees are surrounded by a lake in autumn colors, in the style of nature-inspired imagery, havencore, brightly colored, dark white and dark orange, bright primary colors, environmental activism, forestpunk --ar 64:51"],
-                ["A fluffy mouse holding a watermelon, in a magical and colorful setting, illustrated in the style of Hayao Miyazaki anime by Studio Ghibli."],
-                ["Inka warrior with a war make up, medium shot, natural light, Award winning wildlife photography, hyperrealistic, 8k resolution, --ar 9:16"],
-                ["Character of lion in style of saiyan, mafia, gangsta, citylights background, Hyper detailed, hyper realistic, unreal engine ue5, cgi 3d, cinematic shot, 8k"],
-                ["In the sky above, a giant, whimsical cloud shaped like the 😊 emoji casts a soft, golden light over the scene"],
-                ["Cyberpunk eagle, neon ambiance, abstract black oil, gear mecha, detailed acrylic, grunge, intricate complexity, rendered in unreal engine 5, photorealistic, 8k"],
-                ["close-up photo of a beautiful red rose breaking through a cube made of ice , splintered cracked ice surface, frosted colors, blood dripping from rose, melting ice, Valentine’s Day vibes, cinematic, sharp focus, intricate, cinematic, dramatic light"],
-                ["3D cartoon Fox Head with Human Body, Wearing Iridescent Holographic Liquid Texture & Translucent Material Sun Protective Shirt, Boss Feel, Nike or Addidas Sun Protective Shirt, WitchPunk, Y2K Style, Green and blue, Blue, Metallic Feel, Strong Reflection, plain background, no background, pure single color background, Digital Fashion, Surreal Futurism, Supreme Kong NFT Artwork Style, disney style, headshot photography for portrait studio shoot, fashion editorial aesthetic, high resolution in the style of HAPE PRIME NFT, NFT 3D IP Feel, Bored Ape Yacht Club NFT project Feel, high detail, fine luster, 3D render, oc render, best quality, 8K, bright, front lighting, Face Shot, fine luster, ultra detailed"],
+                [
+                    "A cute Christmas mockup on an old wooden industrial desk table with Christmas decorations and bokeh lights in the background."
+                ],
+                [
+                    "A front view of a romantic flower shop in France filled with various blooming flowers including lavenders and roses."
+                ],
+                [
+                    "An old man, portrayed as a retro superhero, stands in the streets of New York City at night"
+                ],
+                [
+                    "many trees are surrounded by a lake in autumn colors, in the style of nature-inspired imagery, havencore, brightly colored, dark white and dark orange, bright primary colors, environmental activism, forestpunk --ar 64:51"
+                ],
+                [
+                    "A fluffy mouse holding a watermelon, in a magical and colorful setting, illustrated in the style of Hayao Miyazaki anime by Studio Ghibli."
+                ],
+                [
+                    "Inka warrior with a war make up, medium shot, natural light, Award winning wildlife photography, hyperrealistic, 8k resolution, --ar 9:16"
+                ],
+                [
+                    "Character of lion in style of saiyan, mafia, gangsta, citylights background, Hyper detailed, hyper realistic, unreal engine ue5, cgi 3d, cinematic shot, 8k"
+                ],
+                [
+                    "In the sky above, a giant, whimsical cloud shaped like the 😊 emoji casts a soft, golden light over the scene"
+                ],
+                [
+                    "Cyberpunk eagle, neon ambiance, abstract black oil, gear mecha, detailed acrylic, grunge, intricate complexity, rendered in unreal engine 5, photorealistic, 8k"
+                ],
+                [
+                    "close-up photo of a beautiful red rose breaking through a cube made of ice , splintered cracked ice surface, frosted colors, blood dripping from rose, melting ice, Valentine’s Day vibes, cinematic, sharp focus, intricate, cinematic, dramatic light"
+                ],
+                [
+                    "3D cartoon Fox Head with Human Body, Wearing Iridescent Holographic Liquid Texture & Translucent Material Sun Protective Shirt, Boss Feel, Nike or Addidas Sun Protective Shirt, WitchPunk, Y2K Style, Green and blue, Blue, Metallic Feel, Strong Reflection, plain background, no background, pure single color background, Digital Fashion, Surreal Futurism, Supreme Kong NFT Artwork Style, disney style, headshot photography for portrait studio shoot, fashion editorial aesthetic, high resolution in the style of HAPE PRIME NFT, NFT 3D IP Feel, Bored Ape Yacht Club NFT project Feel, high detail, fine luster, 3D render, oc render, best quality, 8K, bright, front lighting, Face Shot, fine luster, ultra detailed"
+                ],
             ],
             [cap],
             label="Examples",
         )
 
-        def on_submit(*args):
-            for q in request_queues:
-                q.put(args)
-            result = response_queue.get()
+        @spaces.GPU(duration=240)
+        def on_submit(*infer_args):
+            result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
             if isinstance(result, ModelFailure):
                 raise RuntimeError("Model failed to generate the image.")
             return result
@@ -606,10 +611,8 @@ def main():
             [output_img],
         )
 
-    # barrier.wait()
     demo.queue(max_size=20).launch()
 
 
 if __name__ == "__main__":
-    # mp.set_start_method("spawn")
     main()
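
For context on the pattern this commit adopts: model weights are now built on CPU while the Space starts up (load_models with device_map="cpu" and a plain model.eval()), and only the Gradio callback is decorated with @spaces.GPU, which attaches a GPU for the duration of a single request before infer_ode moves the weights to CUDA. The snippet below is a minimal, hypothetical sketch of that layout, not the app's actual code; build_pipeline and generate are illustrative stand-ins for load_models and infer_ode, and the toy torch.nn.Linear stands in for the real text encoder, VAE, and DiT.

import gradio as gr
import spaces
import torch


def build_pipeline():
    # Stand-in for load_models(): heavy objects are created on CPU at startup,
    # so no GPU is held before the first request arrives.
    model = torch.nn.Linear(8, 8)
    model.eval()
    return model


model = build_pipeline()


@spaces.GPU(duration=120)  # a GPU is attached only while this call runs
def generate(prompt):
    # Stand-in for infer_ode(): move weights to CUDA inside the decorated
    # function, run inference, and return the result; the GPU is released
    # when the call finishes.
    m = model.cuda()
    with torch.no_grad():
        _ = m(torch.zeros(1, 8, device="cuda"))
    return f"generated: {prompt}"


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    prompt.submit(generate, [prompt], [out])

demo.queue(max_size=20).launch()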