yrr
commited on
Commit
•
7f48662
1
Parent(s):
2c2ec6c
test
Browse files- OmniGen/__init__.py +4 -0
- OmniGen/model.py +402 -0
- OmniGen/pipeline.py +201 -0
- OmniGen/processor.py +349 -0
- OmniGen/scheduler.py +55 -0
- OmniGen/train.py +0 -0
- OmniGen/transformer.py +159 -0
- app.py +59 -145
- edit.png +0 -0
- imgs/.DS_Store +3 -0
- imgs/test_cases/liuyifei.png +0 -0
- imgs/test_cases/taylor.png +0 -0
- imgs/test_cases/trump.png +0 -0
- imgs/test_cases/turing.png +0 -0
- inference.ipynb +0 -0
- setup.py +23 -0
OmniGen/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model import OmniGen
|
2 |
+
from .processor import OmniGenProcessor
|
3 |
+
from .scheduler import OmniGenScheduler
|
4 |
+
from .pipeline import OmniGenPipeline
|
OmniGen/model.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The code is revised from DiT
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
import math
|
7 |
+
from typing import Dict
|
8 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
9 |
+
|
10 |
+
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
11 |
+
|
12 |
+
|
13 |
+
def modulate(x, shift, scale):
|
14 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
15 |
+
|
16 |
+
|
17 |
+
class TimestepEmbedder(nn.Module):
|
18 |
+
"""
|
19 |
+
Embeds scalar timesteps into vector representations.
|
20 |
+
"""
|
21 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
22 |
+
super().__init__()
|
23 |
+
self.mlp = nn.Sequential(
|
24 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
25 |
+
nn.SiLU(),
|
26 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
27 |
+
)
|
28 |
+
self.frequency_embedding_size = frequency_embedding_size
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
def timestep_embedding(t, dim, max_period=10000):
|
32 |
+
"""
|
33 |
+
Create sinusoidal timestep embeddings.
|
34 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
35 |
+
These may be fractional.
|
36 |
+
:param dim: the dimension of the output.
|
37 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
38 |
+
:return: an (N, D) Tensor of positional embeddings.
|
39 |
+
"""
|
40 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
41 |
+
half = dim // 2
|
42 |
+
freqs = torch.exp(
|
43 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
44 |
+
).to(device=t.device)
|
45 |
+
args = t[:, None].float() * freqs[None]
|
46 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
47 |
+
if dim % 2:
|
48 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
49 |
+
return embedding
|
50 |
+
|
51 |
+
def forward(self, t, dtype=torch.float32):
|
52 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
53 |
+
t_emb = self.mlp(t_freq)
|
54 |
+
return t_emb
|
55 |
+
|
56 |
+
|
57 |
+
class FinalLayer(nn.Module):
|
58 |
+
"""
|
59 |
+
The final layer of DiT.
|
60 |
+
"""
|
61 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
62 |
+
super().__init__()
|
63 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
64 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
65 |
+
self.adaLN_modulation = nn.Sequential(
|
66 |
+
nn.SiLU(),
|
67 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
68 |
+
)
|
69 |
+
|
70 |
+
def forward(self, x, c):
|
71 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
72 |
+
x = modulate(self.norm_final(x), shift, scale)
|
73 |
+
x = self.linear(x)
|
74 |
+
return x
|
75 |
+
|
76 |
+
|
77 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
|
78 |
+
"""
|
79 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
80 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
81 |
+
"""
|
82 |
+
if isinstance(grid_size, int):
|
83 |
+
grid_size = (grid_size, grid_size)
|
84 |
+
|
85 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
86 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
87 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
88 |
+
grid = np.stack(grid, axis=0)
|
89 |
+
|
90 |
+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
91 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
92 |
+
if cls_token and extra_tokens > 0:
|
93 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
94 |
+
return pos_embed
|
95 |
+
|
96 |
+
|
97 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
98 |
+
assert embed_dim % 2 == 0
|
99 |
+
|
100 |
+
# use half of dimensions to encode grid_h
|
101 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
102 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
103 |
+
|
104 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
105 |
+
return emb
|
106 |
+
|
107 |
+
|
108 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
109 |
+
"""
|
110 |
+
embed_dim: output dimension for each position
|
111 |
+
pos: a list of positions to be encoded: size (M,)
|
112 |
+
out: (M, D)
|
113 |
+
"""
|
114 |
+
assert embed_dim % 2 == 0
|
115 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
116 |
+
omega /= embed_dim / 2.
|
117 |
+
omega = 1. / 10000**omega # (D/2,)
|
118 |
+
|
119 |
+
pos = pos.reshape(-1) # (M,)
|
120 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
121 |
+
|
122 |
+
emb_sin = np.sin(out) # (M, D/2)
|
123 |
+
emb_cos = np.cos(out) # (M, D/2)
|
124 |
+
|
125 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
126 |
+
return emb
|
127 |
+
|
128 |
+
|
129 |
+
class PatchEmbedMR(nn.Module):
|
130 |
+
""" 2D Image to Patch Embedding
|
131 |
+
"""
|
132 |
+
def __init__(
|
133 |
+
self,
|
134 |
+
patch_size: int = 2,
|
135 |
+
in_chans: int = 4,
|
136 |
+
embed_dim: int = 768,
|
137 |
+
bias: bool = True,
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
x = self.proj(x)
|
144 |
+
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
145 |
+
return x
|
146 |
+
|
147 |
+
|
148 |
+
class OmniGen(nn.Module):
|
149 |
+
"""
|
150 |
+
Diffusion model with a Transformer backbone.
|
151 |
+
"""
|
152 |
+
def __init__(
|
153 |
+
self,
|
154 |
+
transformer_config: Phi3Config,
|
155 |
+
patch_size=2,
|
156 |
+
in_channels=4,
|
157 |
+
pe_interpolation: float = 1.0,
|
158 |
+
pos_embed_max_size: int = 192,
|
159 |
+
):
|
160 |
+
super().__init__()
|
161 |
+
self.in_channels = in_channels
|
162 |
+
self.out_channels = in_channels
|
163 |
+
self.patch_size = patch_size
|
164 |
+
self.pos_embed_max_size = pos_embed_max_size
|
165 |
+
|
166 |
+
hidden_size = transformer_config.hidden_size
|
167 |
+
|
168 |
+
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
169 |
+
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
170 |
+
|
171 |
+
self.time_token = TimestepEmbedder(hidden_size)
|
172 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
173 |
+
|
174 |
+
self.pe_interpolation = pe_interpolation
|
175 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
|
176 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
177 |
+
|
178 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
179 |
+
|
180 |
+
self.initialize_weights()
|
181 |
+
|
182 |
+
self.llm = Phi3Transformer(config=transformer_config)
|
183 |
+
self.llm.config.use_cache = False
|
184 |
+
|
185 |
+
@classmethod
|
186 |
+
def from_pretrained(cls, model_name):
|
187 |
+
if not os.path.exists(os.path.join(model_name, 'model.pt')):
|
188 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
189 |
+
model_name = snapshot_download(repo_id=model_name,
|
190 |
+
cache_dir=cache_folder,
|
191 |
+
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
192 |
+
config = Phi3Config.from_pretrained(model_name)
|
193 |
+
model = cls(config)
|
194 |
+
ckpt = torch.load(os.path.join(model_name, 'model.pt'))
|
195 |
+
model.load_state_dict(ckpt)
|
196 |
+
return model
|
197 |
+
|
198 |
+
def initialize_weights(self):
|
199 |
+
assert not hasattr(self, "llama")
|
200 |
+
|
201 |
+
# Initialize transformer layers:
|
202 |
+
def _basic_init(module):
|
203 |
+
if isinstance(module, nn.Linear):
|
204 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
205 |
+
if module.bias is not None:
|
206 |
+
nn.init.constant_(module.bias, 0)
|
207 |
+
self.apply(_basic_init)
|
208 |
+
|
209 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
210 |
+
w = self.x_embedder.proj.weight.data
|
211 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
212 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
213 |
+
|
214 |
+
w = self.input_x_embedder.proj.weight.data
|
215 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
216 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
217 |
+
|
218 |
+
|
219 |
+
# Initialize timestep embedding MLP:
|
220 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
221 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
222 |
+
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
|
223 |
+
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
|
224 |
+
|
225 |
+
# Zero-out output layers:
|
226 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
227 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
228 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
229 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
230 |
+
|
231 |
+
def unpatchify(self, x, h, w):
|
232 |
+
"""
|
233 |
+
x: (N, T, patch_size**2 * C)
|
234 |
+
imgs: (N, H, W, C)
|
235 |
+
"""
|
236 |
+
c = self.out_channels
|
237 |
+
|
238 |
+
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
|
239 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
240 |
+
imgs = x.reshape(shape=(x.shape[0], c, h, w))
|
241 |
+
return imgs
|
242 |
+
|
243 |
+
|
244 |
+
def cropped_pos_embed(self, height, width):
|
245 |
+
"""Crops positional embeddings for SD3 compatibility."""
|
246 |
+
if self.pos_embed_max_size is None:
|
247 |
+
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
248 |
+
|
249 |
+
height = height // self.patch_size
|
250 |
+
width = width // self.patch_size
|
251 |
+
if height > self.pos_embed_max_size:
|
252 |
+
raise ValueError(
|
253 |
+
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
254 |
+
)
|
255 |
+
if width > self.pos_embed_max_size:
|
256 |
+
raise ValueError(
|
257 |
+
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
258 |
+
)
|
259 |
+
|
260 |
+
top = (self.pos_embed_max_size - height) // 2
|
261 |
+
left = (self.pos_embed_max_size - width) // 2
|
262 |
+
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
263 |
+
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
264 |
+
# print(top, top + height, left, left + width, spatial_pos_embed.size())
|
265 |
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
266 |
+
return spatial_pos_embed
|
267 |
+
|
268 |
+
|
269 |
+
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
|
270 |
+
if isinstance(latents, list):
|
271 |
+
return_list = False
|
272 |
+
if padding_latent is None:
|
273 |
+
padding_latent = [None] * len(latents)
|
274 |
+
return_list = True
|
275 |
+
patched_latents, num_tokens, shapes = [], [], []
|
276 |
+
for latent, padding in zip(latents, padding_latent):
|
277 |
+
height, width = latent.shape[-2:]
|
278 |
+
if is_input_images:
|
279 |
+
latent = self.input_x_embedder(latent)
|
280 |
+
else:
|
281 |
+
latent = self.x_embedder(latent)
|
282 |
+
pos_embed = self.cropped_pos_embed(height, width)
|
283 |
+
latent = latent + pos_embed
|
284 |
+
if padding is not None:
|
285 |
+
latent = torch.cat([latent, padding], dim=-2)
|
286 |
+
patched_latents.append(latent)
|
287 |
+
|
288 |
+
num_tokens.append(pos_embed.size(1))
|
289 |
+
shapes.append([height, width])
|
290 |
+
if not return_list:
|
291 |
+
latents = torch.cat(patched_latents, dim=0)
|
292 |
+
else:
|
293 |
+
latents = patched_latents
|
294 |
+
else:
|
295 |
+
height, width = latents.shape[-2:]
|
296 |
+
if is_input_images:
|
297 |
+
latents = self.input_x_embedder(latents)
|
298 |
+
else:
|
299 |
+
latents = self.x_embedder(latents)
|
300 |
+
pos_embed = self.cropped_pos_embed(height, width)
|
301 |
+
latents = latents + pos_embed
|
302 |
+
num_tokens = latents.size(1)
|
303 |
+
shapes = [height, width]
|
304 |
+
return latents, num_tokens, shapes
|
305 |
+
|
306 |
+
|
307 |
+
def forward(self, x, timestep, text_ids, pixel_values, image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None):
|
308 |
+
"""
|
309 |
+
|
310 |
+
"""
|
311 |
+
input_is_list = isinstance(x, list)
|
312 |
+
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
313 |
+
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
314 |
+
|
315 |
+
if pixel_values is not None:
|
316 |
+
input_latents, _, _ = self.patch_multiple_resolutions(pixel_values, is_input_images=True)
|
317 |
+
if text_ids is not None:
|
318 |
+
condition_embeds = self.llm.embed_tokens(text_ids)
|
319 |
+
input_img_inx = 0
|
320 |
+
for b_inx in image_sizes.keys():
|
321 |
+
for start_inx, end_inx in image_sizes[b_inx]:
|
322 |
+
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
323 |
+
input_img_inx += 1
|
324 |
+
if pixel_values is not None:
|
325 |
+
assert input_img_inx == len(input_latents)
|
326 |
+
|
327 |
+
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
328 |
+
else:
|
329 |
+
input_emb = torch.cat([time_token, x], dim=1)
|
330 |
+
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
|
331 |
+
output, past_key_values = output.last_hidden_state, output.past_key_values
|
332 |
+
if input_is_list:
|
333 |
+
image_embedding = output[:, -max(num_tokens):]
|
334 |
+
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
335 |
+
x = self.final_layer(image_embedding, time_emb)
|
336 |
+
latents = []
|
337 |
+
for i in range(x.size(0)):
|
338 |
+
latent = x[i:i+1, :num_tokens[i]]
|
339 |
+
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
340 |
+
latents.append(latent)
|
341 |
+
else:
|
342 |
+
image_embedding = output[:, -num_tokens:]
|
343 |
+
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
344 |
+
x = self.final_layer(image_embedding, time_emb)
|
345 |
+
latents = self.unpatchify(x, shapes[0], shapes[1])
|
346 |
+
|
347 |
+
return latents, past_key_values
|
348 |
+
|
349 |
+
@torch.no_grad()
|
350 |
+
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
|
351 |
+
"""
|
352 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
353 |
+
"""
|
354 |
+
self.llm.config.use_cache = use_kv_cache
|
355 |
+
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values)
|
356 |
+
if use_img_cfg:
|
357 |
+
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
358 |
+
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
359 |
+
model_out = [cond, cond, cond]
|
360 |
+
else:
|
361 |
+
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
|
362 |
+
cond = uncond + cfg_scale * (cond - uncond)
|
363 |
+
model_out = [cond, cond]
|
364 |
+
|
365 |
+
return torch.cat(model_out, dim=0), past_key_values
|
366 |
+
|
367 |
+
|
368 |
+
@torch.no_grad()
|
369 |
+
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
|
370 |
+
"""
|
371 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
372 |
+
"""
|
373 |
+
self.llm.config.use_cache = use_kv_cache
|
374 |
+
if past_key_values is None:
|
375 |
+
past_key_values = [None] * len(attention_mask)
|
376 |
+
|
377 |
+
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
378 |
+
timestep = timestep.to(x[0].dtype)
|
379 |
+
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
380 |
+
|
381 |
+
model_out, pask_key_values = [], []
|
382 |
+
for i in range(len(input_ids)):
|
383 |
+
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values[i])
|
384 |
+
model_out.append(temp_out)
|
385 |
+
pask_key_values.append(temp_pask_key_values)
|
386 |
+
|
387 |
+
if len(model_out) == 3:
|
388 |
+
cond, uncond, img_cond = model_out
|
389 |
+
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
390 |
+
model_out = [cond, cond, cond]
|
391 |
+
elif len(model_out) == 2:
|
392 |
+
cond, uncond = model_out
|
393 |
+
cond = uncond + cfg_scale * (cond - uncond)
|
394 |
+
model_out = [cond, cond]
|
395 |
+
else:
|
396 |
+
return model_out[0]
|
397 |
+
|
398 |
+
return torch.cat(model_out, dim=0), pask_key_values
|
399 |
+
|
400 |
+
|
401 |
+
|
402 |
+
|
OmniGen/pipeline.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import inspect
|
3 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from huggingface_hub import snapshot_download
|
9 |
+
from diffusers.models import AutoencoderKL
|
10 |
+
from diffusers.utils import (
|
11 |
+
USE_PEFT_BACKEND,
|
12 |
+
is_torch_xla_available,
|
13 |
+
logging,
|
14 |
+
replace_example_docstring,
|
15 |
+
scale_lora_layers,
|
16 |
+
unscale_lora_layers,
|
17 |
+
)
|
18 |
+
|
19 |
+
from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
EXAMPLE_DOC_STRING = """
|
25 |
+
Examples:
|
26 |
+
```py
|
27 |
+
>>> from OmniGen import OmniGenPipeline
|
28 |
+
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
29 |
+
... base_model
|
30 |
+
... )
|
31 |
+
>>> prompt = "A woman holds a bouquet of flowers and faces the camera"
|
32 |
+
>>> image = pipe(
|
33 |
+
... prompt,
|
34 |
+
... guidance_scale=1.0,
|
35 |
+
... num_inference_steps=50,
|
36 |
+
... ).images[0]
|
37 |
+
>>> image.save("t2i.png")
|
38 |
+
```
|
39 |
+
"""
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
class OmniGenPipeline:
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
vae: AutoencoderKL,
|
47 |
+
model: OmniGen,
|
48 |
+
processor: OmniGenProcessor,
|
49 |
+
):
|
50 |
+
self.vae = vae
|
51 |
+
self.model = model
|
52 |
+
self.processor = processor
|
53 |
+
|
54 |
+
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
55 |
+
self.model.to(self.device)
|
56 |
+
self.vae.to(self.device)
|
57 |
+
|
58 |
+
@classmethod
|
59 |
+
def from_pretrained(cls, model_name):
|
60 |
+
if not os.path.exists(model_name):
|
61 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
62 |
+
print(cache_folder)
|
63 |
+
model_name = snapshot_download(repo_id=model_name,
|
64 |
+
cache_dir=cache_folder,
|
65 |
+
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
66 |
+
logger.info(f"Downloaded model to {model_name}")
|
67 |
+
model = OmniGen.from_pretrained(model_name)
|
68 |
+
processor = OmniGenProcessor.from_pretrained(model_name)
|
69 |
+
vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
|
70 |
+
|
71 |
+
return cls(vae, model, processor)
|
72 |
+
|
73 |
+
def vae_encode(self, x, dtype):
|
74 |
+
if self.vae.config.shift_factor is not None:
|
75 |
+
x = self.vae.encode(x).latent_dist.sample()
|
76 |
+
x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
77 |
+
else:
|
78 |
+
x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
|
79 |
+
x = x.to(dtype)
|
80 |
+
return x
|
81 |
+
|
82 |
+
def move_to_device(self, data):
|
83 |
+
if isinstance(data, list):
|
84 |
+
return [x.to(self.device) for x in data]
|
85 |
+
return data.to(self.device)
|
86 |
+
|
87 |
+
|
88 |
+
@torch.no_grad()
|
89 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
90 |
+
def __call__(
|
91 |
+
self,
|
92 |
+
prompt: Union[str, List[str]],
|
93 |
+
input_images: Union[List[str], List[List[str]]] = None,
|
94 |
+
height: int = 1024,
|
95 |
+
width: int = 1024,
|
96 |
+
num_inference_steps: int = 50,
|
97 |
+
guidance_scale: float = 3,
|
98 |
+
use_img_guidance: bool = True,
|
99 |
+
img_guidance_scale: float = 1.6,
|
100 |
+
separate_cfg_infer: bool = False,
|
101 |
+
use_kv_cache: bool = True,
|
102 |
+
dtype: torch.dtype = torch.bfloat16,
|
103 |
+
):
|
104 |
+
r"""
|
105 |
+
Function invoked when calling the pipeline for generation.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
prompt (`str` or `List[str]`):
|
109 |
+
The prompt or prompts to guide the image generation.
|
110 |
+
input_images (`List[str]` or `List[List[str]]`, *optional*):
|
111 |
+
The list of input images. We will replace the "<|image_i|>" in prompt with the 1-th image in list.
|
112 |
+
height (`int`, *optional*, defaults to 1024):
|
113 |
+
The height in pixels of the generated image. The number must be a multiple of 16.
|
114 |
+
width (`int`, *optional*, defaults to 1024):
|
115 |
+
The width in pixels of the generated image. The number must be a multiple of 16.
|
116 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
117 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
|
118 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
119 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
120 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
121 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
122 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
123 |
+
usually at the expense of lower image quality.
|
124 |
+
use_img_guidance (`bool`, *optional*, defaults to True):
|
125 |
+
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
126 |
+
img_guidance_scale (`float`, *optional*, defaults to 1.6):
|
127 |
+
Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
|
128 |
+
separate_cfg_infer (`bool`, *optional*, defaults to False):
|
129 |
+
Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
|
130 |
+
use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
|
131 |
+
|
132 |
+
Examples:
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
A list with the generated images.
|
136 |
+
"""
|
137 |
+
assert height%16 == 0 and width%16 == 0
|
138 |
+
if use_kv_cache and separate_cfg_infer:
|
139 |
+
raise "Currently, don't support both use_kv_cache and separate_cfg_infer"
|
140 |
+
if input_images is None:
|
141 |
+
use_img_guidance = False
|
142 |
+
if isinstance(prompt, str):
|
143 |
+
prompt = [prompt]
|
144 |
+
input_images = [input_images] if input_images is not None else None
|
145 |
+
|
146 |
+
input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer)
|
147 |
+
|
148 |
+
num_prompt = len(prompt)
|
149 |
+
num_cfg = 2 if use_img_guidance else 1
|
150 |
+
latent_size_h, latent_size_w = height//8, width//8
|
151 |
+
|
152 |
+
latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device)
|
153 |
+
latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
|
154 |
+
|
155 |
+
input_img_latents = []
|
156 |
+
if separate_cfg_infer:
|
157 |
+
for temp_pixel_values in input_data['input_pixel_values']:
|
158 |
+
temp_input_latents = []
|
159 |
+
for img in temp_pixel_values:
|
160 |
+
img = self.vae_encode(img.to(self.device), dtype)
|
161 |
+
temp_input_latents.append(img)
|
162 |
+
input_img_latents.append(temp_input_latents)
|
163 |
+
else:
|
164 |
+
for img in input_data['input_pixel_values']:
|
165 |
+
img = self.vae_encode(img.to(self.device), dtype)
|
166 |
+
input_img_latents.append(img)
|
167 |
+
|
168 |
+
model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
|
169 |
+
input_img_latents=input_img_latents,
|
170 |
+
input_image_sizes=input_data['input_image_sizes'],
|
171 |
+
attention_mask=self.move_to_device(input_data["attention_mask"]),
|
172 |
+
position_ids=self.move_to_device(input_data["position_ids"]),
|
173 |
+
cfg_scale=guidance_scale,
|
174 |
+
img_cfg_scale=img_guidance_scale,
|
175 |
+
use_img_cfg=use_img_guidance,
|
176 |
+
use_kv_cache=use_kv_cache)
|
177 |
+
|
178 |
+
if separate_cfg_infer:
|
179 |
+
func = self.model.forward_with_separate_cfg
|
180 |
+
else:
|
181 |
+
func = self.model.forward_with_cfg
|
182 |
+
self.model.to(dtype)
|
183 |
+
|
184 |
+
scheduler = OmniGenScheduler(num_steps=num_inference_steps)
|
185 |
+
samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache)
|
186 |
+
samples = samples.chunk((1+num_cfg), dim=0)[0]
|
187 |
+
|
188 |
+
samples = samples.to(torch.float32)
|
189 |
+
if self.vae.config.shift_factor is not None:
|
190 |
+
samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
191 |
+
else:
|
192 |
+
samples = samples / self.vae.config.scaling_factor
|
193 |
+
samples = self.vae.decode(samples).sample
|
194 |
+
|
195 |
+
output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
|
196 |
+
output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
|
197 |
+
output_images = []
|
198 |
+
for i, sample in enumerate(output_samples):
|
199 |
+
output_images.append(Image.fromarray(sample))
|
200 |
+
|
201 |
+
return output_images
|
OmniGen/processor.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from typing import Dict, List
|
4 |
+
import json
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision import transforms
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
from huggingface_hub import snapshot_download
|
13 |
+
|
14 |
+
|
15 |
+
def crop_arr(pil_image, max_image_size):
|
16 |
+
while min(*pil_image.size) >= 2 * max_image_size:
|
17 |
+
pil_image = pil_image.resize(
|
18 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
19 |
+
)
|
20 |
+
|
21 |
+
if max(*pil_image.size) > max_image_size:
|
22 |
+
scale = max_image_size / max(*pil_image.size)
|
23 |
+
pil_image = pil_image.resize(
|
24 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
25 |
+
)
|
26 |
+
|
27 |
+
arr = np.array(pil_image)
|
28 |
+
crop_y1 = (arr.shape[0] % 16) // 2
|
29 |
+
crop_y2 = arr.shape[0] % 16 - crop_y1
|
30 |
+
|
31 |
+
crop_x1 = (arr.shape[1] % 16) // 2
|
32 |
+
crop_x2 = arr.shape[1] % 16 - crop_x1
|
33 |
+
|
34 |
+
arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
|
35 |
+
return Image.fromarray(arr)
|
36 |
+
|
37 |
+
|
38 |
+
class OmniGenProcessor:
|
39 |
+
def __init__(self,
|
40 |
+
text_tokenizer,
|
41 |
+
max_image_size: int=1024):
|
42 |
+
self.text_tokenizer = text_tokenizer
|
43 |
+
self.max_image_size = max_image_size
|
44 |
+
|
45 |
+
self.image_transform = transforms.Compose([
|
46 |
+
transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
|
47 |
+
transforms.ToTensor(),
|
48 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
49 |
+
])
|
50 |
+
|
51 |
+
self.collator = OmniGenCollator()
|
52 |
+
self.separate_collator = OmniGenSeparateCollator()
|
53 |
+
|
54 |
+
@classmethod
|
55 |
+
def from_pretrained(cls, model_name):
|
56 |
+
if not os.path.exists(model_name):
|
57 |
+
cache_folder = os.getenv('HF_HUB_CACHE')
|
58 |
+
model_name = snapshot_download(repo_id=model_name,
|
59 |
+
cache_dir=cache_folder,
|
60 |
+
allow_patterns="*.json")
|
61 |
+
text_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
62 |
+
|
63 |
+
return cls(text_tokenizer)
|
64 |
+
|
65 |
+
|
66 |
+
def process_image(self, image):
|
67 |
+
image = Image.open(image).convert('RGB')
|
68 |
+
return self.image_transform(image)
|
69 |
+
|
70 |
+
def process_multi_modal_prompt(self, text, input_images):
|
71 |
+
if input_images is None or len(input_images) == 0:
|
72 |
+
model_inputs = self.text_tokenizer(text)
|
73 |
+
return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
|
74 |
+
|
75 |
+
pattern = r"<\|image_\d+\|>"
|
76 |
+
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
|
77 |
+
|
78 |
+
for i in range(1, len(prompt_chunks)):
|
79 |
+
if prompt_chunks[i][0] == 1:
|
80 |
+
prompt_chunks[i] = prompt_chunks[i][1:]
|
81 |
+
|
82 |
+
image_tags = re.findall(pattern, text)
|
83 |
+
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
84 |
+
|
85 |
+
unique_image_ids = sorted(list(set(image_ids)))
|
86 |
+
assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
|
87 |
+
# total images must be the same as the number of image tags
|
88 |
+
assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
|
89 |
+
|
90 |
+
input_images = [input_images[x-1] for x in image_ids]
|
91 |
+
|
92 |
+
all_input_ids = []
|
93 |
+
img_inx = []
|
94 |
+
idx = 0
|
95 |
+
for i in range(len(prompt_chunks)):
|
96 |
+
all_input_ids.extend(prompt_chunks[i])
|
97 |
+
if i != len(prompt_chunks) -1:
|
98 |
+
start_inx = len(all_input_ids)
|
99 |
+
size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
|
100 |
+
img_inx.append([start_inx, start_inx+size])
|
101 |
+
all_input_ids.extend([0]*size)
|
102 |
+
|
103 |
+
return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
|
104 |
+
|
105 |
+
|
106 |
+
def add_prefix_instruction(self, prompt):
|
107 |
+
user_prompt = '<|user|>\n'
|
108 |
+
generation_prompt = 'Generate an image according to the following instructions\n'
|
109 |
+
assistant_prompt = '<|assistant|>\n<|diffusion|>'
|
110 |
+
prompt_suffix = "<|end|>\n"
|
111 |
+
prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
|
112 |
+
return prompt
|
113 |
+
|
114 |
+
|
115 |
+
def __call__(self,
|
116 |
+
instructions: List[str],
|
117 |
+
input_images: List[List[str]] = None,
|
118 |
+
height: int = 1024,
|
119 |
+
width: int = 1024,
|
120 |
+
negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
|
121 |
+
use_img_cfg: bool = True,
|
122 |
+
separate_cfg_input: bool = False,
|
123 |
+
) -> Dict:
|
124 |
+
|
125 |
+
if input_images is None:
|
126 |
+
use_img_cfg = False
|
127 |
+
if isinstance(instructions, str):
|
128 |
+
instructions = [instructions]
|
129 |
+
input_images = [input_images]
|
130 |
+
|
131 |
+
input_data = []
|
132 |
+
for i in range(len(instructions)):
|
133 |
+
cur_instruction = instructions[i]
|
134 |
+
cur_input_images = None if input_images is None else input_images[i]
|
135 |
+
cur_instruction = self.add_prefix_instruction(cur_instruction)
|
136 |
+
if cur_input_images is not None and len(cur_input_images) > 0:
|
137 |
+
cur_input_images = [self.process_image(x) for x in cur_input_images]
|
138 |
+
else:
|
139 |
+
cur_input_images = None
|
140 |
+
assert "<img><|image_1|></img>" not in cur_instruction
|
141 |
+
|
142 |
+
mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
|
143 |
+
|
144 |
+
|
145 |
+
neg_mllm_input, img_cfg_mllm_input = None, None
|
146 |
+
neg_instruction = self.add_prefix_instruction(negative_prompt)
|
147 |
+
neg_mllm_input = self.process_multi_modal_prompt(neg_instruction, None)
|
148 |
+
if use_img_cfg:
|
149 |
+
if cur_input_images is not None and len(cur_input_images) >= 1:
|
150 |
+
img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
|
151 |
+
img_cfg_mllm_input = self.process_multi_modal_prompt(self.add_prefix_instruction(" ".join(img_cfg_prompt)), cur_input_images)
|
152 |
+
else:
|
153 |
+
img_cfg_mllm_input = neg_instruction
|
154 |
+
|
155 |
+
input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
|
156 |
+
|
157 |
+
if separate_cfg_input:
|
158 |
+
return self.separate_collator(input_data)
|
159 |
+
return self.collator(input_data)
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
class OmniGenCollator:
|
165 |
+
def __init__(self, pad_token_id=2, hidden_size=3072):
|
166 |
+
self.pad_token_id = pad_token_id
|
167 |
+
self.hidden_size = hidden_size
|
168 |
+
|
169 |
+
def create_position(self, attention_mask, num_tokens_for_output_images):
|
170 |
+
position_ids = []
|
171 |
+
text_length = attention_mask.size(-1)
|
172 |
+
img_length = max(num_tokens_for_output_images)
|
173 |
+
for mask in attention_mask:
|
174 |
+
temp_l = torch.sum(mask)
|
175 |
+
temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
|
176 |
+
position_ids.append(temp_position)
|
177 |
+
return torch.LongTensor(position_ids)
|
178 |
+
|
179 |
+
def create_mask(self, attention_mask, num_tokens_for_output_images):
|
180 |
+
extended_mask = []
|
181 |
+
padding_images = []
|
182 |
+
text_length = attention_mask.size(-1)
|
183 |
+
img_length = max(num_tokens_for_output_images)
|
184 |
+
seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
|
185 |
+
inx = 0
|
186 |
+
for mask in attention_mask:
|
187 |
+
temp_l = torch.sum(mask)
|
188 |
+
pad_l = text_length - temp_l
|
189 |
+
|
190 |
+
temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
|
191 |
+
|
192 |
+
image_mask = torch.zeros(size=(temp_l+1, img_length))
|
193 |
+
temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
|
194 |
+
|
195 |
+
image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
|
196 |
+
temp_mask = torch.cat([temp_mask, image_mask], dim=0)
|
197 |
+
|
198 |
+
if pad_l > 0:
|
199 |
+
pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
|
200 |
+
temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
|
201 |
+
|
202 |
+
pad_mask = torch.ones(size=(pad_l, seq_len))
|
203 |
+
temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
|
204 |
+
|
205 |
+
true_img_length = num_tokens_for_output_images[inx]
|
206 |
+
pad_img_length = img_length - true_img_length
|
207 |
+
if pad_img_length > 0:
|
208 |
+
temp_mask[:, -pad_img_length:] = 0
|
209 |
+
temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
|
210 |
+
else:
|
211 |
+
temp_padding_imgs = None
|
212 |
+
|
213 |
+
extended_mask.append(temp_mask.unsqueeze(0))
|
214 |
+
padding_images.append(temp_padding_imgs)
|
215 |
+
inx += 1
|
216 |
+
return torch.cat(extended_mask, dim=0), padding_images
|
217 |
+
|
218 |
+
def adjust_attention_for_input_images(self, attention_mask, image_sizes):
|
219 |
+
for b_inx in image_sizes.keys():
|
220 |
+
for start_inx, end_inx in image_sizes[b_inx]:
|
221 |
+
attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
|
222 |
+
|
223 |
+
return attention_mask
|
224 |
+
|
225 |
+
def pad_input_ids(self, input_ids, image_sizes):
|
226 |
+
max_l = max([len(x) for x in input_ids])
|
227 |
+
padded_ids = []
|
228 |
+
attention_mask = []
|
229 |
+
new_image_sizes = []
|
230 |
+
|
231 |
+
for i in range(len(input_ids)):
|
232 |
+
temp_ids = input_ids[i]
|
233 |
+
temp_l = len(temp_ids)
|
234 |
+
pad_l = max_l - temp_l
|
235 |
+
if pad_l == 0:
|
236 |
+
attention_mask.append([1]*max_l)
|
237 |
+
padded_ids.append(temp_ids)
|
238 |
+
else:
|
239 |
+
attention_mask.append([0]*pad_l+[1]*temp_l)
|
240 |
+
padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
|
241 |
+
|
242 |
+
if i in image_sizes:
|
243 |
+
new_inx = []
|
244 |
+
for old_inx in image_sizes[i]:
|
245 |
+
new_inx.append([x+pad_l for x in old_inx])
|
246 |
+
image_sizes[i] = new_inx
|
247 |
+
|
248 |
+
return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
|
249 |
+
|
250 |
+
|
251 |
+
def process_mllm_input(self, mllm_inputs, target_img_size):
|
252 |
+
num_tokens_for_output_images = []
|
253 |
+
for img_size in target_img_size:
|
254 |
+
num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
|
255 |
+
|
256 |
+
pixel_values, image_sizes = [], {}
|
257 |
+
b_inx = 0
|
258 |
+
for x in mllm_inputs:
|
259 |
+
if x['pixel_values'] is not None:
|
260 |
+
pixel_values.extend(x['pixel_values'])
|
261 |
+
for size in x['image_sizes']:
|
262 |
+
if b_inx not in image_sizes:
|
263 |
+
image_sizes[b_inx] = [size]
|
264 |
+
else:
|
265 |
+
image_sizes[b_inx].append(size)
|
266 |
+
b_inx += 1
|
267 |
+
pixel_values = [x.unsqueeze(0) for x in pixel_values]
|
268 |
+
|
269 |
+
|
270 |
+
input_ids = [x['input_ids'] for x in mllm_inputs]
|
271 |
+
padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
|
272 |
+
position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
|
273 |
+
attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
|
274 |
+
attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
|
275 |
+
|
276 |
+
return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
|
277 |
+
|
278 |
+
|
279 |
+
def __call__(self, features):
|
280 |
+
mllm_inputs = [f[0] for f in features]
|
281 |
+
cfg_mllm_inputs = [f[1] for f in features]
|
282 |
+
img_cfg_mllm_input = [f[2] for f in features]
|
283 |
+
target_img_size = [f[3] for f in features]
|
284 |
+
|
285 |
+
|
286 |
+
if img_cfg_mllm_input[0] is not None:
|
287 |
+
mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
|
288 |
+
target_img_size = target_img_size + target_img_size + target_img_size
|
289 |
+
else:
|
290 |
+
mllm_inputs = mllm_inputs + cfg_mllm_inputs
|
291 |
+
target_img_size = target_img_size + target_img_size
|
292 |
+
|
293 |
+
|
294 |
+
all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
295 |
+
|
296 |
+
data = {"input_ids": all_padded_input_ids,
|
297 |
+
"attention_mask": all_attention_mask,
|
298 |
+
"position_ids": all_position_ids,
|
299 |
+
"input_pixel_values": all_pixel_values,
|
300 |
+
"input_image_sizes": all_image_sizes,
|
301 |
+
"padding_images": all_padding_images,
|
302 |
+
}
|
303 |
+
return data
|
304 |
+
|
305 |
+
|
306 |
+
class OmniGenSeparateCollator(OmniGenCollator):
|
307 |
+
def __call__(self, features):
|
308 |
+
mllm_inputs = [f[0] for f in features]
|
309 |
+
cfg_mllm_inputs = [f[1] for f in features]
|
310 |
+
img_cfg_mllm_input = [f[2] for f in features]
|
311 |
+
target_img_size = [f[3] for f in features]
|
312 |
+
|
313 |
+
|
314 |
+
all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
|
315 |
+
|
316 |
+
|
317 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
|
318 |
+
all_padded_input_ids.append(padded_input_ids)
|
319 |
+
all_attention_mask.append(attention_mask)
|
320 |
+
all_position_ids.append(position_ids)
|
321 |
+
all_pixel_values.append(pixel_values)
|
322 |
+
all_image_sizes.append(image_sizes)
|
323 |
+
all_padding_images.append(padding_images)
|
324 |
+
|
325 |
+
if cfg_mllm_inputs[0] is not None:
|
326 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
|
327 |
+
all_padded_input_ids.append(padded_input_ids)
|
328 |
+
all_attention_mask.append(attention_mask)
|
329 |
+
all_position_ids.append(position_ids)
|
330 |
+
all_pixel_values.append(pixel_values)
|
331 |
+
all_image_sizes.append(image_sizes)
|
332 |
+
all_padding_images.append(padding_images)
|
333 |
+
if img_cfg_mllm_input[0] is not None:
|
334 |
+
padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
|
335 |
+
all_padded_input_ids.append(padded_input_ids)
|
336 |
+
all_attention_mask.append(attention_mask)
|
337 |
+
all_position_ids.append(position_ids)
|
338 |
+
all_pixel_values.append(pixel_values)
|
339 |
+
all_image_sizes.append(image_sizes)
|
340 |
+
all_padding_images.append(padding_images)
|
341 |
+
|
342 |
+
data = {"input_ids": all_padded_input_ids,
|
343 |
+
"attention_mask": all_attention_mask,
|
344 |
+
"position_ids": all_position_ids,
|
345 |
+
"input_pixel_values": all_pixel_values,
|
346 |
+
"input_image_sizes": all_image_sizes,
|
347 |
+
"padding_images": all_padding_images,
|
348 |
+
}
|
349 |
+
return data
|
OmniGen/scheduler.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
|
4 |
+
|
5 |
+
class OmniGenScheduler:
|
6 |
+
def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
|
7 |
+
self.num_steps = num_steps
|
8 |
+
self.time_shift = time_shifting_factor
|
9 |
+
|
10 |
+
t = torch.linspace(0, 1, num_steps+1)
|
11 |
+
t = t / (t + time_shifting_factor - time_shifting_factor * t)
|
12 |
+
self.sigma = t
|
13 |
+
|
14 |
+
def crop_kv_cache(self, past_key_values, num_tokens_for_img):
|
15 |
+
crop_past_key_values = ()
|
16 |
+
for layer_idx in range(len(past_key_values)):
|
17 |
+
key_states, value_states = past_key_values[layer_idx][:2]
|
18 |
+
crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
|
19 |
+
return crop_past_key_values
|
20 |
+
# return DynamicCache.from_legacy_cache(crop_past_key_values)
|
21 |
+
|
22 |
+
def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
|
23 |
+
if isinstance(position_ids, list):
|
24 |
+
for i in range(len(position_ids)):
|
25 |
+
position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
|
26 |
+
else:
|
27 |
+
position_ids = position_ids[:, -(num_tokens_for_img+1):]
|
28 |
+
return position_ids
|
29 |
+
|
30 |
+
def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
|
31 |
+
if isinstance(attention_mask, list):
|
32 |
+
return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
|
33 |
+
return attention_mask[..., -(num_tokens_for_img+1):, :]
|
34 |
+
|
35 |
+
def __call__(self, z, func, model_kwargs, use_kv_cache: bool=True):
|
36 |
+
past_key_values = None
|
37 |
+
for i in tqdm(range(self.num_steps)):
|
38 |
+
timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
|
39 |
+
pred, temp_past_key_values = func(z, timesteps, past_key_values=past_key_values, **model_kwargs)
|
40 |
+
sigma_next = self.sigma[i+1]
|
41 |
+
sigma = self.sigma[i]
|
42 |
+
z = z + (sigma_next - sigma) * pred
|
43 |
+
if i == 0 and use_kv_cache:
|
44 |
+
num_tokens_for_img = z.size(-1)*z.size(-2) // 4
|
45 |
+
if isinstance(temp_past_key_values, list):
|
46 |
+
past_key_values = [self.crop_kv_cache(x, num_tokens_for_img) for x in temp_past_key_values]
|
47 |
+
model_kwargs['input_ids'] = [None] * len(temp_past_key_values)
|
48 |
+
else:
|
49 |
+
past_key_values = self.crop_kv_cache(temp_past_key_values, num_tokens_for_img)
|
50 |
+
model_kwargs['input_ids'] = None
|
51 |
+
|
52 |
+
model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
|
53 |
+
model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
|
54 |
+
return z
|
55 |
+
|
OmniGen/train.py
ADDED
File without changes
|
OmniGen/transformer.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
|
11 |
+
from transformers.modeling_outputs import (
|
12 |
+
BaseModelOutputWithPast,
|
13 |
+
CausalLMOutputWithPast,
|
14 |
+
SequenceClassifierOutputWithPast,
|
15 |
+
TokenClassifierOutput,
|
16 |
+
)
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers import Phi3Config, Phi3Model
|
19 |
+
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache, OffloadedCache
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class Phi3Transformer(Phi3Model):
|
26 |
+
"""
|
27 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
28 |
+
We only modified the attention mask
|
29 |
+
Args:
|
30 |
+
config: Phi3Config
|
31 |
+
"""
|
32 |
+
|
33 |
+
def forward(
|
34 |
+
self,
|
35 |
+
input_ids: torch.LongTensor = None,
|
36 |
+
attention_mask: Optional[torch.Tensor] = None,
|
37 |
+
position_ids: Optional[torch.LongTensor] = None,
|
38 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
39 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
40 |
+
use_cache: Optional[bool] = None,
|
41 |
+
output_attentions: Optional[bool] = None,
|
42 |
+
output_hidden_states: Optional[bool] = None,
|
43 |
+
return_dict: Optional[bool] = None,
|
44 |
+
cache_position: Optional[torch.LongTensor] = None,
|
45 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
46 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
47 |
+
output_hidden_states = (
|
48 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
49 |
+
)
|
50 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
51 |
+
|
52 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
53 |
+
|
54 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
55 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
56 |
+
|
57 |
+
if self.gradient_checkpointing and self.training:
|
58 |
+
if use_cache:
|
59 |
+
logger.warning_once(
|
60 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
61 |
+
)
|
62 |
+
use_cache = False
|
63 |
+
|
64 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
65 |
+
return_legacy_cache = False
|
66 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
67 |
+
return_legacy_cache = True
|
68 |
+
if past_key_values is None:
|
69 |
+
past_key_values = DynamicCache()
|
70 |
+
else:
|
71 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
72 |
+
logger.warning_once(
|
73 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
74 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
75 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
76 |
+
)
|
77 |
+
|
78 |
+
if inputs_embeds is None:
|
79 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
80 |
+
|
81 |
+
if cache_position is None:
|
82 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
83 |
+
cache_position = torch.arange(
|
84 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
85 |
+
)
|
86 |
+
if position_ids is None:
|
87 |
+
position_ids = cache_position.unsqueeze(0)
|
88 |
+
|
89 |
+
if attention_mask is not None and attention_mask.dim() == 3:
|
90 |
+
dtype = inputs_embeds.dtype
|
91 |
+
min_dtype = torch.finfo(dtype).min
|
92 |
+
attention_mask = (1 - attention_mask) * min_dtype
|
93 |
+
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
94 |
+
else:
|
95 |
+
raise
|
96 |
+
# causal_mask = self._update_causal_mask(
|
97 |
+
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
98 |
+
# )
|
99 |
+
|
100 |
+
hidden_states = inputs_embeds
|
101 |
+
|
102 |
+
# decoder layers
|
103 |
+
all_hidden_states = () if output_hidden_states else None
|
104 |
+
all_self_attns = () if output_attentions else None
|
105 |
+
next_decoder_cache = None
|
106 |
+
|
107 |
+
for decoder_layer in self.layers:
|
108 |
+
if output_hidden_states:
|
109 |
+
all_hidden_states += (hidden_states,)
|
110 |
+
|
111 |
+
if self.gradient_checkpointing and self.training:
|
112 |
+
layer_outputs = self._gradient_checkpointing_func(
|
113 |
+
decoder_layer.__call__,
|
114 |
+
hidden_states,
|
115 |
+
attention_mask,
|
116 |
+
position_ids,
|
117 |
+
past_key_values,
|
118 |
+
output_attentions,
|
119 |
+
use_cache,
|
120 |
+
cache_position,
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
layer_outputs = decoder_layer(
|
124 |
+
hidden_states,
|
125 |
+
attention_mask=attention_mask,
|
126 |
+
position_ids=position_ids,
|
127 |
+
past_key_value=past_key_values,
|
128 |
+
output_attentions=output_attentions,
|
129 |
+
use_cache=use_cache,
|
130 |
+
cache_position=cache_position,
|
131 |
+
)
|
132 |
+
|
133 |
+
hidden_states = layer_outputs[0]
|
134 |
+
|
135 |
+
if use_cache:
|
136 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
137 |
+
|
138 |
+
if output_attentions:
|
139 |
+
all_self_attns += (layer_outputs[1],)
|
140 |
+
|
141 |
+
hidden_states = self.norm(hidden_states)
|
142 |
+
|
143 |
+
# add hidden states from the last decoder layer
|
144 |
+
if output_hidden_states:
|
145 |
+
all_hidden_states += (hidden_states,)
|
146 |
+
|
147 |
+
next_cache = next_decoder_cache if use_cache else None
|
148 |
+
if return_legacy_cache:
|
149 |
+
next_cache = next_cache.to_legacy_cache()
|
150 |
+
|
151 |
+
if not return_dict:
|
152 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
153 |
+
return BaseModelOutputWithPast(
|
154 |
+
last_hidden_state=hidden_states,
|
155 |
+
past_key_values=next_cache,
|
156 |
+
hidden_states=all_hidden_states,
|
157 |
+
attentions=all_self_attns,
|
158 |
+
)
|
159 |
+
|
app.py
CHANGED
@@ -1,154 +1,68 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
-
import
|
4 |
|
5 |
-
|
6 |
-
from diffusers import DiffusionPipeline
|
7 |
-
import torch
|
8 |
|
9 |
-
|
10 |
-
model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
|
11 |
|
12 |
-
|
13 |
-
torch_dtype = torch.float16
|
14 |
-
else:
|
15 |
-
torch_dtype = torch.float32
|
16 |
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
# @spaces.GPU #[uncomment to use ZeroGPU]
|
25 |
-
def infer(
|
26 |
-
prompt,
|
27 |
-
negative_prompt,
|
28 |
-
seed,
|
29 |
-
randomize_seed,
|
30 |
-
width,
|
31 |
-
height,
|
32 |
-
guidance_scale,
|
33 |
-
num_inference_steps,
|
34 |
-
progress=gr.Progress(track_tqdm=True),
|
35 |
-
):
|
36 |
-
if randomize_seed:
|
37 |
-
seed = random.randint(0, MAX_SEED)
|
38 |
-
|
39 |
-
generator = torch.Generator().manual_seed(seed)
|
40 |
-
|
41 |
-
image = pipe(
|
42 |
-
prompt=prompt,
|
43 |
-
negative_prompt=negative_prompt,
|
44 |
-
guidance_scale=guidance_scale,
|
45 |
-
num_inference_steps=num_inference_steps,
|
46 |
-
width=width,
|
47 |
height=height,
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
)
|
91 |
-
|
92 |
-
seed = gr.Slider(
|
93 |
-
label="Seed",
|
94 |
-
minimum=0,
|
95 |
-
maximum=MAX_SEED,
|
96 |
-
step=1,
|
97 |
-
value=0,
|
98 |
-
)
|
99 |
-
|
100 |
-
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
101 |
-
|
102 |
-
with gr.Row():
|
103 |
-
width = gr.Slider(
|
104 |
-
label="Width",
|
105 |
-
minimum=256,
|
106 |
-
maximum=MAX_IMAGE_SIZE,
|
107 |
-
step=32,
|
108 |
-
value=1024, # Replace with defaults that work for your model
|
109 |
-
)
|
110 |
-
|
111 |
-
height = gr.Slider(
|
112 |
-
label="Height",
|
113 |
-
minimum=256,
|
114 |
-
maximum=MAX_IMAGE_SIZE,
|
115 |
-
step=32,
|
116 |
-
value=1024, # Replace with defaults that work for your model
|
117 |
-
)
|
118 |
-
|
119 |
-
with gr.Row():
|
120 |
-
guidance_scale = gr.Slider(
|
121 |
-
label="Guidance scale",
|
122 |
-
minimum=0.0,
|
123 |
-
maximum=10.0,
|
124 |
-
step=0.1,
|
125 |
-
value=0.0, # Replace with defaults that work for your model
|
126 |
-
)
|
127 |
-
|
128 |
-
num_inference_steps = gr.Slider(
|
129 |
-
label="Number of inference steps",
|
130 |
-
minimum=1,
|
131 |
-
maximum=50,
|
132 |
-
step=1,
|
133 |
-
value=2, # Replace with defaults that work for your model
|
134 |
-
)
|
135 |
-
|
136 |
-
gr.Examples(examples=examples, inputs=[prompt])
|
137 |
-
gr.on(
|
138 |
-
triggers=[run_button.click, prompt.submit],
|
139 |
-
fn=infer,
|
140 |
-
inputs=[
|
141 |
-
prompt,
|
142 |
-
negative_prompt,
|
143 |
-
seed,
|
144 |
-
randomize_seed,
|
145 |
-
width,
|
146 |
-
height,
|
147 |
-
guidance_scale,
|
148 |
-
num_inference_steps,
|
149 |
-
],
|
150 |
-
outputs=[result, seed],
|
151 |
)
|
152 |
|
153 |
-
|
154 |
-
|
|
|
1 |
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
import os
|
4 |
|
5 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
|
|
|
|
|
6 |
|
7 |
+
from OmniGen import OmniGenPipeline
|
|
|
8 |
|
9 |
+
pipe = OmniGenPipeline.from_pretrained("shitao/tmp-preview")
|
|
|
|
|
|
|
10 |
|
11 |
+
# 示例处理函数:生成图像
|
12 |
+
def generate_image(text, img1, img2, img3, height, width, guidance_scale):
|
13 |
+
input_images = [img1, img2, img3]
|
14 |
+
# 去除 None
|
15 |
+
input_images = [img for img in input_images if img is not None]
|
16 |
+
if len(input_images) == 0:
|
17 |
+
input_images = None
|
18 |
|
19 |
+
output = pipe(
|
20 |
+
prompt=text,
|
21 |
+
input_images=input_images,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
height=height,
|
23 |
+
width=width,
|
24 |
+
guidance_scale=guidance_scale,
|
25 |
+
img_guidance_scale=1.6,
|
26 |
+
separate_cfg_infer=True,
|
27 |
+
use_kv_cache=False
|
28 |
+
)
|
29 |
+
img = output[0]
|
30 |
+
return img
|
31 |
+
|
32 |
+
# Gradio 接口
|
33 |
+
with gr.Blocks() as demo:
|
34 |
+
gr.Markdown("## Text + Multiple Images to Image Generator")
|
35 |
+
|
36 |
+
with gr.Row():
|
37 |
+
with gr.Column():
|
38 |
+
# 文本输入框
|
39 |
+
prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your prompt here...")
|
40 |
+
|
41 |
+
# 图片上传框
|
42 |
+
image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
|
43 |
+
image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
|
44 |
+
image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
|
45 |
+
|
46 |
+
# 高度和宽度滑块
|
47 |
+
height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=1024, step=16)
|
48 |
+
width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1024, step=16)
|
49 |
+
|
50 |
+
# 引导尺度输入
|
51 |
+
guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
|
52 |
+
|
53 |
+
# 生成按钮
|
54 |
+
generate_button = gr.Button("Generate Image")
|
55 |
+
|
56 |
+
with gr.Column():
|
57 |
+
# 输出图像框
|
58 |
+
output_image = gr.Image(label="Output Image")
|
59 |
+
|
60 |
+
# 按钮点击事件
|
61 |
+
generate_button.click(
|
62 |
+
generate_image,
|
63 |
+
inputs=[prompt_input, image_input_1, image_input_2, image_input_3, height_input, width_input, guidance_scale_input],
|
64 |
+
outputs=output_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
)
|
66 |
|
67 |
+
# 启动应用
|
68 |
+
demo.launch()
|
edit.png
ADDED
imgs/.DS_Store
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d65165279105ca6773180500688df4bdc69a2c7b771752f0a46ef120b7fd8ec3
|
3 |
+
size 6148
|
imgs/test_cases/liuyifei.png
ADDED
imgs/test_cases/taylor.png
ADDED
imgs/test_cases/trump.png
ADDED
imgs/test_cases/turing.png
ADDED
inference.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
setup.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
with open("README.md", mode="r", encoding="utf-8") as readme_file:
|
4 |
+
readme = readme_file.read()
|
5 |
+
|
6 |
+
setup(
|
7 |
+
name='OmniGen',
|
8 |
+
version='1.0.0',
|
9 |
+
description='OmniGen',
|
10 |
+
long_description=readme,
|
11 |
+
long_description_content_type="text/markdown",
|
12 |
+
author_email='2906698981@qq.com',
|
13 |
+
url='https://github.com/VectorSpaceLab/OmniGen',
|
14 |
+
packages=find_packages(),
|
15 |
+
include_package_data=True,
|
16 |
+
install_requires=[
|
17 |
+
'torch>=1.6.0',
|
18 |
+
'transformers>=4.41.0',
|
19 |
+
'datasets',
|
20 |
+
'accelerate>=0.20.1',
|
21 |
+
'diffusers>=0.30.3'
|
22 |
+
],
|
23 |
+
)
|