more examples
- OmniGen/__pycache__/__init__.cpython-310.pyc +0 -0
- OmniGen/__pycache__/model.cpython-310.pyc +0 -0
- OmniGen/__pycache__/pipeline.cpython-310.pyc +0 -0
- OmniGen/__pycache__/processor.cpython-310.pyc +0 -0
- OmniGen/__pycache__/scheduler.cpython-310.pyc +0 -0
- OmniGen/__pycache__/transformer.cpython-310.pyc +0 -0
- OmniGen/__pycache__/utils.cpython-310.pyc +0 -0
- OmniGen/pipeline.py +5 -2
- OmniGen/train_helper/__init__.py +2 -0
- OmniGen/train_helper/data.py +116 -0
- OmniGen/train_helper/loss.py +68 -0
- app.py +87 -16
- imgs/demo_cases/edit.png +2 -2
- imgs/demo_cases/entity.png +2 -2
- imgs/demo_cases/reasoning.png +2 -2
- imgs/demo_cases/same_pose.png +2 -2
- imgs/demo_cases/skeletal.png +2 -2
- imgs/demo_cases/skeletal2img.png +2 -2
- imgs/{demo_cases.png → test_cases/1.jpg} +2 -2
- imgs/{overall.jpg → test_cases/2.jpg} +2 -2
- imgs/test_cases/3.jpg +3 -0
- imgs/test_cases/4.jpg +3 -0
- imgs/test_cases/Amanda.jpg +3 -0
- imgs/test_cases/icl1.jpg +3 -0
- imgs/test_cases/icl2.jpg +3 -0
- imgs/test_cases/icl3.jpg +3 -0
- imgs/test_cases/mckenna.jpg +3 -0
- imgs/test_cases/rose.jpg +3 -0
- imgs/test_cases/vase.jpg +3 -0
- imgs/test_cases/zhang.png +3 -0
OmniGen/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/__init__.cpython-310.pyc and b/OmniGen/__pycache__/__init__.cpython-310.pyc differ

OmniGen/__pycache__/model.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/model.cpython-310.pyc and b/OmniGen/__pycache__/model.cpython-310.pyc differ

OmniGen/__pycache__/pipeline.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/pipeline.cpython-310.pyc and b/OmniGen/__pycache__/pipeline.cpython-310.pyc differ

OmniGen/__pycache__/processor.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/processor.cpython-310.pyc and b/OmniGen/__pycache__/processor.cpython-310.pyc differ

OmniGen/__pycache__/scheduler.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/scheduler.cpython-310.pyc and b/OmniGen/__pycache__/scheduler.cpython-310.pyc differ

OmniGen/__pycache__/transformer.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/transformer.cpython-310.pyc and b/OmniGen/__pycache__/transformer.cpython-310.pyc differ

OmniGen/__pycache__/utils.cpython-310.pyc
CHANGED
Binary files a/OmniGen/__pycache__/utils.cpython-310.pyc and b/OmniGen/__pycache__/utils.cpython-310.pyc differ
OmniGen/pipeline.py
CHANGED
@@ -16,6 +16,7 @@ from diffusers.utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
+from safetensors.torch import load_file
 
 from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
 
@@ -59,12 +60,12 @@ class OmniGenPipeline:
 
     @classmethod
     def from_pretrained(cls, model_name, vae_path: str=None):
-        if not os.path.exists(model_name):
+        if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
             logger.info("Model not found, downloading...")
             cache_folder = os.getenv('HF_HUB_CACHE')
             model_name = snapshot_download(repo_id=model_name,
                                            cache_dir=cache_folder,
-                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
             logger.info(f"Downloaded model to {model_name}")
         model = OmniGen.from_pretrained(model_name)
         processor = OmniGenProcessor.from_pretrained(model_name)
@@ -82,6 +83,8 @@ class OmniGenPipeline:
     def merge_lora(self, lora_path: str):
         model = PeftModel.from_pretrained(self.model, lora_path)
         model.merge_and_unload()
+
+
        self.model = model
 
    def to(self, device: Union[str, torch.device]):
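For orientation, the updated `from_pretrained` path can be exercised roughly as below. This is a minimal sketch, assuming the published `Shitao/OmniGen-v1` checkpoint and the pipeline interface visible in this diff; the LoRA path is a placeholder.

```python
from OmniGen import OmniGenPipeline

# First use downloads the checkpoint; with this change the download is also
# triggered when model.safetensors is missing locally, and 'model.pt' is skipped.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

# Optionally fold LoRA weights into the base model (placeholder path).
# pipe.merge_lora("path/to/lora_checkpoint")

images = pipe(
    prompt="A curly-haired man in a red shirt is drinking tea.",
    height=1024,
    width=1024,
    guidance_scale=2.5,
    seed=0,
)
images[0].save("example_t2i.png")
```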
OmniGen/train_helper/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .data import DatasetFromJson, TrainDataCollator
+from .loss import training_losses
OmniGen/train_helper/data.py
ADDED
@@ -0,0 +1,116 @@
+import os
+import datasets
+from datasets import load_dataset, ClassLabel, concatenate_datasets
+import torch
+import numpy as np
+import random
+from PIL import Image
+import json
+import copy
+# import torchvision.transforms as T
+from torchvision import transforms
+import pickle
+import re
+
+from OmniGen import OmniGenProcessor
+from OmniGen.processor import OmniGenCollator
+
+
+class DatasetFromJson(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        json_file: str,
+        image_path: str,
+        processer: OmniGenProcessor,
+        image_transform,
+        max_input_length_limit: int = 18000,
+        condition_dropout_prob: float = 0.1,
+        keep_raw_resolution: bool = True,
+    ):
+
+        self.image_transform = image_transform
+        self.processer = processer
+        self.condition_dropout_prob = condition_dropout_prob
+        self.max_input_length_limit = max_input_length_limit
+        self.keep_raw_resolution = keep_raw_resolution
+
+        self.data = load_dataset('json', data_files=json_file)['train']
+        self.image_path = image_path
+
+    def process_image(self, image_file):
+        if self.image_path is not None:
+            image_file = os.path.join(self.image_path, image_file)
+        image = Image.open(image_file).convert('RGB')
+        return self.image_transform(image)
+
+    def get_example(self, index):
+        example = self.data[index]
+
+        instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
+        if random.random() < self.condition_dropout_prob:
+            instruction = '<cfg>'
+            input_images = None
+        if input_images is not None:
+            input_images = [self.process_image(x) for x in input_images]
+        mllm_input = self.processer.process_multi_modal_prompt(instruction, input_images)
+
+        output_image = self.process_image(output_image)
+
+        return (mllm_input, output_image)
+
+
+    def __getitem__(self, index):
+        return self.get_example(index)
+        for _ in range(8):
+            try:
+                mllm_input, output_image = self.get_example(index)
+                if len(mllm_input['input_ids']) > self.max_input_length_limit:
+                    raise RuntimeError(f"cur number of tokens={len(mllm_input['input_ids'])}, larger than max_input_length_limit={self.max_input_length_limit}")
+                return mllm_input, output_image
+            except Exception as e:
+                print("error when loading data: ", e)
+                print(self.data[index])
+                index = random.randint(0, len(self.data)-1)
+        raise RuntimeError("Too many bad data.")
+
+
+    def __len__(self):
+        return len(self.data)
+
+
+
+class TrainDataCollator(OmniGenCollator):
+    def __init__(self, pad_token_id: int, hidden_size: int, keep_raw_resolution: bool):
+        self.pad_token_id = pad_token_id
+        self.hidden_size = hidden_size
+        self.keep_raw_resolution = keep_raw_resolution
+
+    def __call__(self, features):
+        mllm_inputs = [f[0] for f in features]
+
+        output_images = [f[1].unsqueeze(0) for f in features]
+        target_img_size = [[x.size(-2), x.size(-1)] for x in output_images]
+
+        all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+
+        if not self.keep_raw_resolution:
+            output_image = torch.cat(output_image, dim=0)
+            if len(pixel_values) > 0:
+                all_pixel_values = torch.cat(all_pixel_values, dim=0)
+            else:
+                all_pixel_values = None
+
+        data = {"input_ids": all_padded_input_ids,
+                "attention_mask": all_attention_mask,
+                "position_ids": all_position_ids,
+                "input_pixel_values": all_pixel_values,
+                "input_image_sizes": all_image_sizes,
+                "padding_images": all_padding_images,
+                "output_images": output_images,
+                }
+        return data
+
+
+
+
+
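As a rough guide to how these pieces fit together, the sketch below wires `DatasetFromJson` and `TrainDataCollator` into a standard PyTorch `DataLoader`. It is a sketch only: the JSONL/image paths are placeholders, and the normalization transform, the `hidden_size` value, and the `processor.text_tokenizer.pad_token_id` accessor are assumptions rather than something this diff specifies.

```python
import torch
from torchvision import transforms

from OmniGen import OmniGenProcessor
from OmniGen.train_helper import DatasetFromJson, TrainDataCollator

processor = OmniGenProcessor.from_pretrained("Shitao/OmniGen-v1")

# Assumed transform: map PIL images to tensors normalized to [-1, 1].
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset = DatasetFromJson(
    json_file="train_data.jsonl",      # placeholder: one JSON object per line with
    image_path="./train_images",       # 'instruction', 'input_images', 'output_image'
    processer=processor,               # note the parameter spelling used above
    image_transform=image_transform,
    condition_dropout_prob=0.1,
)

collator = TrainDataCollator(
    pad_token_id=processor.text_tokenizer.pad_token_id,  # assumed accessor
    hidden_size=3072,                                     # assumed model hidden size
    keep_raw_resolution=True,
)

loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
print(batch.keys())
```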
OmniGen/train_helper/loss.py
ADDED
@@ -0,0 +1,68 @@
+import torch
+
+
+def sample_x0(x1):
+    """Sampling x0 & t based on shape of x1 (if needed)
+    Args:
+      x1 - data point; [batch, *dim]
+    """
+    if isinstance(x1, (list, tuple)):
+        x0 = [torch.randn_like(img_start) for img_start in x1]
+    else:
+        x0 = torch.randn_like(x1)
+
+    return x0
+
+def sample_timestep(x1):
+    u = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
+    t = 1 / (1 + torch.exp(-u))
+    t = t.to(x1[0])
+    return t
+
+
+def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
+    """Loss for training torche score model
+    Args:
+    - model: backbone model; could be score, noise, or velocity
+    - x1: datapoint
+    - model_kwargs: additional arguments for torche model
+    """
+    if model_kwargs == None:
+        model_kwargs = {}
+
+    B = len(x1)
+
+    x0 = sample_x0(x1)
+    t = sample_timestep(x1)
+
+    if isinstance(x1, (list, tuple)):
+        xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)]
+        ut = [x1[i] - x0[i] for i in range(B)]
+    else:
+        dims = [1] * (len(x1.size()) - 1)
+        t_ = t.view(t.size(0), *dims)
+        xt = t_ * x1 + (1 - t_) * x0
+        ut = x1 - x0
+
+    model_output = model(xt, t, **model_kwargs)
+
+    terms = {}
+
+    if isinstance(x1, (list, tuple)):
+        assert len(model_output) == len(ut) == len(x1)
+        for i in range(B):
+            terms["loss"] = torch.stack(
+                [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
+                dim=0,
+            )
+    else:
+        terms["loss"] = mean_flat(((model_output - ut) ** 2))
+
+    return terms
+
+
+def mean_flat(x):
+    """
+    Take torche mean over all non-batch dimensions.
+    """
+    return torch.mean(x, dim=list(range(1, len(x.size()))))
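In effect, `training_losses` is a rectified-flow / flow-matching style objective: `t` is drawn from a logit-normal distribution (sigmoid of a standard normal), the sample is interpolated as x_t = t·x1 + (1−t)·x0, and the model output is regressed onto the velocity x1 − x0 with a mean-squared error. Below is a minimal, self-contained check with a stand-in velocity model; the dummy model and tensor shapes are illustrative, not part of the commit.

```python
import torch
from OmniGen.train_helper import training_losses

class DummyVelocityModel(torch.nn.Module):
    """Stand-in for the real backbone: takes (x_t, t, **kwargs) and returns a
    tensor shaped like the velocity target, matching model(xt, t, **model_kwargs)."""
    def __init__(self, channels: int = 4):
        super().__init__()
        self.proj = torch.nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x_t, t, **kwargs):
        # Broadcast the per-sample timestep over channel/spatial dims.
        return self.proj(x_t) + t.view(-1, 1, 1, 1)

model = DummyVelocityModel()
x1 = torch.randn(2, 4, 32, 32)   # pretend latents for a batch of two samples

terms = training_losses(model, x1)
loss = terms["loss"].mean()      # terms["loss"] holds per-sample losses, shape (2,)
loss.backward()
print(float(loss))
```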
app.py
CHANGED
@@ -11,7 +11,7 @@ pipe = OmniGenPipeline.from_pretrained(
 
 @spaces.GPU(duration=180)
 # Example handler: generate an image
-def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
+def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed):
     input_images = [img1, img2, img3]
     # Drop None entries
     input_images = [img for img in input_images if img is not None]
@@ -26,7 +26,7 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, infere
         guidance_scale=guidance_scale,
         img_guidance_scale=1.6,
         num_inference_steps=inference_steps,
-        separate_cfg_infer=True,
+        separate_cfg_infer=True,  # set False can speed up the inference process
         use_kv_cache=False,
         seed=seed,
     )
@@ -47,26 +47,28 @@ def generate_image(text, img1, img2, img3, height, width, guidance_scale, infere
 def get_example():
     case = [
         [
-            "A
+            "A curly-haired man in a red shirt is drinking tea.",
             None,
             None,
             None,
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "
+            "The woman in <img><|image_1|></img> waves her hand happily in the crowd",
             "./imgs/test_cases/yifei2.png",
             None,
             None,
             1024,
             1024,
             2.5,
+            1.9,
             50,
-
+            128,
         ],
         [
             "A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
@@ -76,17 +78,55 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "Two
-            "./imgs/test_cases/
-            "./imgs/test_cases/
+            "Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. The other woman is <img><|image_2|></img>.",
+            "./imgs/test_cases/mckenna.jpg",
+            "./imgs/test_cases/Amanda.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.8,
+            50,
+            168,
+        ],
+        [
+            "A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <img><|image_1|></img>, and the woman is oldest woman in <img><|image_2|></img>",
+            "./imgs/test_cases/1.jpg",
+            "./imgs/test_cases/2.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.6,
+            50,
+            60,
+        ],
+        [
+            "A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. The woman is the woman on the left of <img><|image_2|></img>",
+            "./imgs/test_cases/3.jpg",
+            "./imgs/test_cases/4.jpg",
             None,
             1024,
             1024,
             2.5,
+            1.8,
+            50,
+            66,
+        ],
+        [
+            "The flower <img><|image_1|><\/img> is placed in the vase which is in the middle of <img><|image_2|><\/img> on a wooden table of a living room",
+            "./imgs/test_cases/rose.jpg",
+            "./imgs/test_cases/vase.jpg",
+            None,
+            1024,
+            1024,
+            2.5,
+            1.6,
             50,
             0,
         ],
@@ -98,6 +138,7 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             222,
         ],
@@ -109,6 +150,7 @@ def get_example():
             1024,
             1024,
             2.0,
+            1.6,
             50,
             0,
         ],
@@ -120,6 +162,7 @@ def get_example():
             1024,
             1024,
             2,
+            1.6,
             50,
             42,
         ],
@@ -131,9 +174,22 @@ def get_example():
             1024,
             1024,
             2.0,
+            1.6,
             50,
             123,
         ],
+        [
+            "Following the depth mapping of this image <img><|image_1|><img>, generate a new photo: A young girl is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him.",
+            "./imgs/demo_cases/edit.png",
+            None,
+            None,
+            1024,
+            1024,
+            2.0,
+            1.6,
+            50,
+            1,
+        ],
         [
             "<img><|image_1|><\/img> What item can be used to see the current time? Please remove it.",
             "./imgs/test_cases/watch.jpg",
@@ -142,25 +198,27 @@ def get_example():
             1024,
             1024,
             2.5,
+            1.6,
             50,
             0,
         ],
         [
-            "
-            "./imgs/test_cases/
-            "./imgs/test_cases/
-            "./imgs/test_cases/
+            "According to the following examples, generate an output for the input.\nInput: <img><|image_1|></img>\nOutput: <img><|image_2|></img>\n\nInput: <img><|image_3|></img>\nOutput: ",
+            "./imgs/test_cases/icl1.jpg",
+            "./imgs/test_cases/icl2.jpg",
+            "./imgs/test_cases/icl3.jpg",
             1024,
             1024,
             2.5,
+            1.6,
             50,
-
+            1,
         ],
     ]
     return case
 
-def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
-    return generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed)
+def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed):
+    return generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed)
 
 description = """
 OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, and image-conditioned generation.
@@ -168,6 +226,13 @@ OmniGen is a unified image generation model that you can use to perform various
 For multi-modal to image generation, you should pass a string as `prompt`, and a list of image paths as `input_images`. The placeholder in the prompt should be in the format of `<img><|image_*|></img>` (for the first image, the placeholder is <img><|image_1|></img>. for the second image, the the placeholder is <img><|image_2|></img>).
 For example, use an image of a woman to generate a new image:
 prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is \<img\>\<|image_1|\>\</img\>."
+
+Tips:
+- Oversaturated: If the image appears oversaturated, please reduce the `guidance_scale`.
+- Low-quality: More detailed prompt will lead to better results.
+- Animate Style: If the genereate images is in animate style, you can try to add `photo` to the prompt`.
+- Edit generated image. If you generate a image by omnigen and then want to edit it, you cannot use the same seed to edit this image. For example, use seed=0 to generate image, and should use seed=1 to edit this image.
+- For image editing tasks, we recommend placing the image before the editing instruction. For example, use `<img><|image_1|></img> remove suit`, rather than `remove suit <img><|image_1|></img>`.
 """
 
 # Gradio interface
@@ -197,7 +262,11 @@ with gr.Blocks() as demo:
 
         # Guidance scale input
         guidance_scale_input = gr.Slider(
-            label="Guidance Scale", minimum=1.0, maximum=
+            label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
+        )
+
+        img_guidance_scale_input = gr.Slider(
+            label="img_guidance_scale", minimum=1.0, maximum=2.0, value=1.6, step=0.1
         )
 
         num_inference_steps = gr.Slider(
@@ -226,6 +295,7 @@ with gr.Blocks() as demo:
                 height_input,
                 width_input,
                 guidance_scale_input,
+                img_guidance_scale_input,
                 num_inference_steps,
                 seed_input,
             ],
@@ -243,6 +313,7 @@ with gr.Blocks() as demo:
                 height_input,
                 width_input,
                 guidance_scale_input,
+                img_guidance_scale_input,
                 num_inference_steps,
                 seed_input,
             ],
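For reference, the call that the updated `generate_image` wires up corresponds roughly to the direct pipeline invocation below. This is a sketch assembled from the arguments visible in this diff (the prompt, input image, and parameter values come from the second example case); it is not new functionality.

```python
from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

images = pipe(
    prompt="The woman in <img><|image_1|></img> waves her hand happily in the crowd",
    input_images=["./imgs/test_cases/yifei2.png"],
    height=1024,
    width=1024,
    guidance_scale=2.5,       # text guidance; the Gradio slider now caps this at 5.0
    img_guidance_scale=1.9,   # image guidance, exposed by the new slider (1.0-2.0)
    num_inference_steps=50,
    separate_cfg_infer=True,  # per the added comment: set False to speed up inference
    use_kv_cache=False,
    seed=128,
)
images[0].save("waving_woman.png")
```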
imgs/demo_cases/edit.png
CHANGED
imgs/demo_cases/entity.png
CHANGED
imgs/demo_cases/reasoning.png
CHANGED
imgs/demo_cases/same_pose.png
CHANGED
imgs/demo_cases/skeletal.png
CHANGED
imgs/demo_cases/skeletal2img.png
CHANGED
imgs/{demo_cases.png → test_cases/1.jpg}
RENAMED
File without changes
imgs/{overall.jpg → test_cases/2.jpg}
RENAMED
File without changes
imgs/test_cases/3.jpg
ADDED
imgs/test_cases/4.jpg
ADDED
imgs/test_cases/Amanda.jpg
ADDED
imgs/test_cases/icl1.jpg
ADDED
imgs/test_cases/icl2.jpg
ADDED
imgs/test_cases/icl3.jpg
ADDED
imgs/test_cases/mckenna.jpg
ADDED
imgs/test_cases/rose.jpg
ADDED
imgs/test_cases/vase.jpg
ADDED
imgs/test_cases/zhang.png
ADDED