cb1cyf commited on
Commit
56fe0df
·
1 Parent(s): b81fc8a

fix: format

Browse files
Files changed (1) hide show
  1. app.py +33 -26
app.py CHANGED
@@ -31,6 +31,7 @@ import torch
31
  from torchvision.transforms.functional import to_pil_image, to_tensor
32
 
33
  from accelerate import Accelerator
 
34
  from peft import LoraConfig
35
  from safetensors.torch import load_file
36
 
@@ -41,7 +42,6 @@ from omnigen2.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultiste
41
  from omnigen2.utils.img_util import create_collage
42
 
43
  NEGATIVE_PROMPT = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"
44
- ROOT_DIR = "projects/OmniGen2"
45
  SAVE_DIR = "output/gradio"
46
 
47
  pipeline = None
@@ -59,20 +59,22 @@ def load_pipeline(accelerator, weight_dtype, args):
59
  subfolder="transformer",
60
  torch_dtype=weight_dtype,
61
  )
62
- if args.lora_path is not None:
63
- target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
64
- lora_config = LoraConfig(
65
- r=512,
66
- lora_alpha=512,
67
- lora_dropout=0,
68
- init_lora_weights="gaussian",
69
- target_modules=target_modules,
70
- )
71
- pipeline.transformer.add_adapter(lora_config)
72
- lora_state_dict = load_file(args.lora_path, device=accelerator.device.__str__())
73
- pipeline.transformer.load_state_dict(lora_state_dict, strict=False)
74
- pipeline.transformer.fuse_lora(lora_scale=1, safe_fusing=False, adapter_names=["default"])
75
- pipeline.transformer.unload_lora()
 
 
76
  if args.enable_sequential_cpu_offload:
77
  pipeline.enable_sequential_cpu_offload()
78
  elif args.enable_model_cpu_offload:
@@ -153,6 +155,7 @@ def run(
153
  vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
154
  output_image = create_collage(vis_images)
155
 
 
156
  if save_images:
157
  # Create outputs directory if it doesn't exist
158
  output_dir = SAVE_DIR
@@ -171,7 +174,7 @@ def run(
171
  for i, image in enumerate(results.images):
172
  image_name, ext = os.path.splitext(output_path)
173
  image.save(f"{image_name}_{i}{ext}")
174
- return output_image
175
 
176
 
177
  def get_examples(base_dir="assets/examples/OmniGen2"):
@@ -199,9 +202,8 @@ badges_text = r"""
199
  <div style="text-align: center; display: flex; justify-content: center; gap: 5px;">
200
  <a href="https://github.com/bytedance/UMO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UMO"></a>
201
  <a href="https://bytedance.github.io/UMO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UMO-yellow"></a>
202
- <a href="https://arxiv.org/abs/25xx.xxxxx"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UMO-b31b1b.svg"></a>
203
  <a href="https://huggingface.co/bytedance-research/UMO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
204
- <a href="https://huggingface.co/spaces/bytedance-research/UMO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
205
  </div>
206
  """.strip()
207
 
@@ -227,11 +229,14 @@ tips = """
227
 
228
  article = """
229
  ```bibtex
230
- @article{cheng2025umo,
231
- title={UMO: Scaling Multi-Identity Consistency for Image Customization via Matching Reward},
232
- author={Cheng, Yufeng and Wu, Wenxu and Wu, Shaojin and Huang, Mengqi and Ding, Fei and He, Qian},
233
- journal={arXiv preprint arXiv:25xx.xxxxx},
234
- year={2025}
 
 
 
235
  }
236
  ```
237
  """.strip()
@@ -384,9 +389,11 @@ def main(args):
384
  # output image
385
  output_image = gr.Image(label="Output Image")
386
  global save_images
387
- save_images = gr.Checkbox(label="Save generated images", value=True)
 
388
  with gr.Accordion("Examples Comparison with OmniGen2", open=False):
389
  output_image_omnigen2 = gr.Image(label="Generated Image (OmniGen2)")
 
390
 
391
  gr.Markdown(star)
392
 
@@ -422,7 +429,7 @@ def main(args):
422
  seed_input,
423
  align_res,
424
  ],
425
- outputs=output_image,
426
  )
427
 
428
  gr.Examples(
@@ -444,7 +451,7 @@ def main(args):
444
  )
445
 
446
  # launch
447
- demo.launch(share=args.share, server_port=args.port, allowed_paths=[ROOT_DIR], server_name=args.server_name)
448
 
449
  def parse_args():
450
  parser = argparse.ArgumentParser()
 
31
  from torchvision.transforms.functional import to_pil_image, to_tensor
32
 
33
  from accelerate import Accelerator
34
+ from huggingface_hub import hf_hub_download
35
  from peft import LoraConfig
36
  from safetensors.torch import load_file
37
 
 
42
  from omnigen2.utils.img_util import create_collage
43
 
44
  NEGATIVE_PROMPT = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"
 
45
  SAVE_DIR = "output/gradio"
46
 
47
  pipeline = None
 
59
  subfolder="transformer",
60
  torch_dtype=weight_dtype,
61
  )
62
+
63
+ lora_path = hf_hub_download("bytedance-research/UMO", "UMO_OmniGen2.safetensors") if args.lora_path is None else args.lora_path
64
+ target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
65
+ lora_config = LoraConfig(
66
+ r=512,
67
+ lora_alpha=512,
68
+ lora_dropout=0,
69
+ init_lora_weights="gaussian",
70
+ target_modules=target_modules,
71
+ )
72
+ pipeline.transformer.add_adapter(lora_config)
73
+ lora_state_dict = load_file(lora_path, device=accelerator.device.__str__())
74
+ pipeline.transformer.load_state_dict(lora_state_dict, strict=False)
75
+ pipeline.transformer.fuse_lora(lora_scale=1, safe_fusing=False, adapter_names=["default"])
76
+ pipeline.transformer.unload_lora()
77
+
78
  if args.enable_sequential_cpu_offload:
79
  pipeline.enable_sequential_cpu_offload()
80
  elif args.enable_model_cpu_offload:
 
155
  vis_images = [to_tensor(image) * 2 - 1 for image in results.images]
156
  output_image = create_collage(vis_images)
157
 
158
+ output_path = ""
159
  if save_images:
160
  # Create outputs directory if it doesn't exist
161
  output_dir = SAVE_DIR
 
174
  for i, image in enumerate(results.images):
175
  image_name, ext = os.path.splitext(output_path)
176
  image.save(f"{image_name}_{i}{ext}")
177
+ return output_image, output_path
178
 
179
 
180
  def get_examples(base_dir="assets/examples/OmniGen2"):
 
202
  <div style="text-align: center; display: flex; justify-content: center; gap: 5px;">
203
  <a href="https://github.com/bytedance/UMO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UMO"></a>
204
  <a href="https://bytedance.github.io/UMO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UMO-yellow"></a>
205
+ <a href="https://arxiv.org/abs/2509.06818"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UMO-b31b1b.svg"></a>
206
  <a href="https://huggingface.co/bytedance-research/UMO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
 
207
  </div>
208
  """.strip()
209
 
 
229
 
230
  article = """
231
  ```bibtex
232
+ @misc{cheng2025umoscalingmultiidentityconsistency,
233
+ title={UMO: Scaling Multi-Identity Consistency for Image Customization via Matching Reward},
234
+ author={Yufeng Cheng and Wenxu Wu and Shaojin Wu and Mengqi Huang and Fei Ding and Qian He},
235
+ year={2025},
236
+ eprint={2509.06818},
237
+ archivePrefix={arXiv},
238
+ primaryClass={cs.CV},
239
+ url={https://arxiv.org/abs/2509.06818},
240
  }
241
  ```
242
  """.strip()
 
389
  # output image
390
  output_image = gr.Image(label="Output Image")
391
  global save_images
392
+ # save_images = gr.Checkbox(label="Save generated images", value=True)
393
+ save_images = True
394
  with gr.Accordion("Examples Comparison with OmniGen2", open=False):
395
  output_image_omnigen2 = gr.Image(label="Generated Image (OmniGen2)")
396
+ download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
397
 
398
  gr.Markdown(star)
399
 
 
429
  seed_input,
430
  align_res,
431
  ],
432
+ outputs=[output_image, download_btn],
433
  )
434
 
435
  gr.Examples(
 
451
  )
452
 
453
  # launch
454
+ demo.launch(share=args.share, server_port=args.port, server_name=args.server_name, ssr_mode=False)
455
 
456
  def parse_args():
457
  parser = argparse.ArgumentParser()