Omer Karisman committed
Commit 8daac10
1 Parent(s): 6ddb532
Files changed (1)
  1. app.py +185 -24
app.py CHANGED
@@ -9,14 +9,81 @@ import torch
 torch.jit.script = lambda f: f
 ####
 
-from omni_zero_spaces import OmniZeroCouple
+from huggingface_hub import snapshot_download
+from diffusers import DPMSolverMultistepScheduler
+from diffusers.models import ControlNetModel
+from diffusers.image_processor import IPAdapterMaskProcessor
 
-omni_zero = OmniZeroCouple(
-    base_model="frankjoshua/albedobaseXL_v13",
-    device="cuda",
-)
+from transformers import CLIPVisionModelWithProjection
 
-omni_zero.generate = spaces.GPU(omni_zero.generate)
+from pipeline import OmniZeroPipeline
+from insightface.app import FaceAnalysis
+from controlnet_aux import ZoeDetector
+from utils import draw_kps, load_and_resize_image, align_images
+
+import cv2
+import numpy as np
+
+import PIL
+
+def patch_onnx_runtime(
+    self,
+    inter_op_num_threads: int = 16,
+    intra_op_num_threads: int = 16,
+    omp_num_threads: int = 16,
+):
+    import os
+    import onnxruntime as ort
+
+    os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
+
+    _default_session_options = ort.capi._pybind_state.get_default_session_options()
+
+    def get_default_session_options_new():
+        _default_session_options.inter_op_num_threads = inter_op_num_threads
+        _default_session_options.intra_op_num_threads = intra_op_num_threads
+        return _default_session_options
+
+    ort.capi._pybind_state.get_default_session_options = get_default_session_options_new
+
+
+base_model = "frankjoshua/albedobaseXL_v13"
+
+patch_onnx_runtime()
+
+snapshot_download("okaris/antelopev2", local_dir="./models/antelopev2")
+face_analysis = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+face_analysis.prepare(ctx_id=0, det_size=(640, 640))
+
+dtype = torch.float16
+
+ip_adapter_plus_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    "h94/IP-Adapter",
+    subfolder="models/image_encoder",
+    torch_dtype=dtype,
+).to("cuda")
+
+zoedepthnet_path = "okaris/zoe-depth-controlnet-xl"
+zoedepthnet = ControlNetModel.from_pretrained(zoedepthnet_path, torch_dtype=dtype).to("cuda")
+
+identitiynet_path = "okaris/face-controlnet-xl"
+identitynet = ControlNetModel.from_pretrained(identitiynet_path, torch_dtype=dtype).to("cuda")
+
+zoe_depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
+ip_adapter_mask_processor = IPAdapterMaskProcessor()
+
+pipeline = OmniZeroPipeline.from_pretrained(
+    base_model,
+    controlnet=[identitynet, identitynet, zoedepthnet],
+    torch_dtype=dtype,
+    image_encoder=ip_adapter_plus_image_encoder,
+).to("cuda")
+
+config = pipeline.scheduler.config
+config["timestep_spacing"] = "trailing"
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++", final_sigmas_type="zero")
+
+pipeline.load_ip_adapter(["okaris/ip-adapter-instantid", "okaris/ip-adapter-instantid", "h94/IP-Adapter"], subfolder=[None, None, "sdxl_models"], weight_name=["ip-adapter-instantid.bin", "ip-adapter-instantid.bin", "ip-adapter-plus_sdxl_vit-h.safetensors"])
 
 @spaces.GPU()
 def generate(
@@ -40,26 +107,120 @@ def generate(
     mask_guidance_end=1.0,
     progress=gr.Progress(track_tqdm=True)
 ):
-    images = omni_zero.generate(
-        seed=seed,
+    resolution = 1024
+
+    if base_image is not None:
+        base_image = load_and_resize_image(base_image, resolution, resolution)
+
+    if depth_image is None:
+        depth_image = zoe_depth_detector(base_image, detect_resolution=resolution, image_resolution=resolution)
+    else:
+        depth_image = load_and_resize_image(depth_image, resolution, resolution)
+
+    base_image, depth_image = align_images(base_image, depth_image)
+
+    if style_image is not None:
+        style_image = load_and_resize_image(style_image, resolution, resolution)
+    else:
+        raise ValueError("You must provide a style image")
+
+    if identity_image_1 is not None:
+        identity_image_1 = load_and_resize_image(identity_image_1, resolution, resolution)
+    else:
+        raise ValueError("You must provide an identity image")
+
+    if identity_image_2 is not None:
+        identity_image_2 = load_and_resize_image(identity_image_2, resolution, resolution)
+    else:
+        raise ValueError("You must provide an identity image 2")
+
+    height, width = base_image.size
+
+    face_info_1 = face_analysis.get(cv2.cvtColor(np.array(identity_image_1), cv2.COLOR_RGB2BGR))
+    for i, face in enumerate(face_info_1):
+        print(f"Face 1 -{i}: Age: {face['age']}, Gender: {face['gender']}")
+    face_info_1 = sorted(face_info_1, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
+    face_emb_1 = torch.tensor(face_info_1['embedding']).to("cuda", dtype=dtype)
+
+    face_info_2 = face_analysis.get(cv2.cvtColor(np.array(identity_image_2), cv2.COLOR_RGB2BGR))
+    for i, face in enumerate(face_info_2):
+        print(f"Face 2 -{i}: Age: {face['age']}, Gender: {face['gender']}")
+    face_info_2 = sorted(face_info_2, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
+    face_emb_2 = torch.tensor(face_info_2['embedding']).to("cuda", dtype=dtype)
+
+    zero = np.zeros((width, height, 3), dtype=np.uint8)
+    # face_kps_identity_image_1 = draw_kps(zero, face_info_1['kps'])
+    # face_kps_identity_image_2 = draw_kps(zero, face_info_2['kps'])
+
+    face_info_img2img = face_analysis.get(cv2.cvtColor(np.array(base_image), cv2.COLOR_RGB2BGR))
+    faces_info_img2img = sorted(face_info_img2img, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])
+    face_info_a = faces_info_img2img[-1]
+    face_info_b = faces_info_img2img[-2]
+    # face_emb_a = torch.tensor(face_info_a['embedding']).to("cuda", dtype=dtype)
+    # face_emb_b = torch.tensor(face_info_b['embedding']).to("cuda", dtype=dtype)
+    face_kps_identity_image_a = draw_kps(zero, face_info_a['kps'])
+    face_kps_identity_image_b = draw_kps(zero, face_info_b['kps'])
+
+    general_mask = PIL.Image.fromarray(np.ones((width, height, 3), dtype=np.uint8))
+
+    control_mask_1 = zero.copy()
+    x1, y1, x2, y2 = face_info_a["bbox"]
+    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+    control_mask_1[y1:y2, x1:x2] = 255
+    control_mask_1 = PIL.Image.fromarray(control_mask_1.astype(np.uint8))
+
+    control_mask_2 = zero.copy()
+    x1, y1, x2, y2 = face_info_b["bbox"]
+    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+    control_mask_2[y1:y2, x1:x2] = 255
+    control_mask_2 = PIL.Image.fromarray(control_mask_2.astype(np.uint8))
+
+    controlnet_masks = [control_mask_1, control_mask_2, general_mask]
+    ip_adapter_images = [face_emb_1, face_emb_2, style_image, ]
+
+    masks = ip_adapter_mask_processor.preprocess([control_mask_1, control_mask_2, general_mask], height=height, width=width)
+    ip_adapter_masks = [mask.unsqueeze(0) for mask in masks]
+
+    inpaint_mask = torch.logical_or(torch.tensor(np.array(control_mask_1)), torch.tensor(np.array(control_mask_2))).float()
+    inpaint_mask = PIL.Image.fromarray((inpaint_mask.numpy() * 255).astype(np.uint8)).convert("RGB")
+
+    new_ip_adapter_masks = []
+    for ip_img, mask in zip(ip_adapter_images, controlnet_masks):
+        if isinstance(ip_img, list):
+            num_images = len(ip_img)
+            mask = mask.repeat(1, num_images, 1, 1)
+
+        new_ip_adapter_masks.append(mask)
+
+    generator = torch.Generator(device="cpu").manual_seed(seed)
+
+    pipeline.set_ip_adapter_scale([identity_image_strength_1, identity_image_strength_2,
+        {
+            "down": { "block_2": [0.0, 0.0] }, # Composition
+            "up": { "block_0": [0.0, style_image_strength, 0.0] } # Style
+        }
+    ])
+
+    images = pipeline(
         prompt=prompt,
-        negative_prompt=negative_prompt,
+        negative_prompt=negative_prompt,
         guidance_scale=guidance_scale,
-        number_of_images=number_of_images,
-        number_of_steps=number_of_steps,
-        base_image=base_image,
-        base_image_strength=base_image_strength,
-        style_image=style_image,
-        style_image_strength=style_image_strength,
-        identity_image_1=identity_image_1,
-        identity_image_strength_1=identity_image_strength_1,
-        identity_image_2=identity_image_2,
-        identity_image_strength_2=identity_image_strength_2,
-        depth_image=depth_image,
-        depth_image_strength=depth_image_strength,
-        mask_guidance_start=mask_guidance_start,
-        mask_guidance_end=mask_guidance_end,
-    )
+        num_inference_steps=number_of_steps,
+        num_images_per_prompt=number_of_images,
+        ip_adapter_image=ip_adapter_images,
+        cross_attention_kwargs={"ip_adapter_masks": ip_adapter_masks},
+        image=base_image,
+        mask_image=inpaint_mask,
+        i2i_mask_guidance_start=mask_guidance_start,
+        i2i_mask_guidance_end=mask_guidance_end,
+        control_image=[face_kps_identity_image_a, face_kps_identity_image_b, depth_image],
+        control_mask=controlnet_masks,
+        identity_control_indices=[(0,0), (1,1)],
+        controlnet_conditioning_scale=[identity_image_strength_1, identity_image_strength_2, depth_image_strength],
+        strength=1-base_image_strength,
+        generator=generator,
+        seed=seed,
+    ).images
 
     return images
 
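The Gradio UI and the full `generate` signature sit in the unchanged part of app.py and are not shown in this diff. As a minimal, hypothetical sketch of a direct call, using only the parameter names that appear in the changed body (all values, file names, and defaults below are assumptions, not part of the commit):

# Hypothetical usage sketch, not part of this commit: calls generate() directly
# with the keyword arguments referenced in the new function body.
from PIL import Image

images = generate(
    seed=42,
    prompt="a portrait of two people",
    negative_prompt="blurry, low quality",
    guidance_scale=3.0,
    number_of_images=1,
    number_of_steps=10,
    base_image=Image.open("base.png"),            # resized to 1024x1024 inside generate(); should contain two faces, the two largest drive the keypoint ControlNets
    base_image_strength=0.15,                     # pipeline strength is 1 - base_image_strength
    style_image=Image.open("style.png"),          # required, otherwise a ValueError is raised
    style_image_strength=1.0,
    identity_image_1=Image.open("person_1.png"),  # required, must contain a detectable face
    identity_image_strength_1=1.0,
    identity_image_2=Image.open("person_2.png"),  # required, must contain a detectable face
    identity_image_strength_2=1.0,
    depth_image=None,                             # None: ZoeDepth is run on base_image instead
    depth_image_strength=0.5,
    mask_guidance_start=0.0,
    mask_guidance_end=1.0,
)
images[0].save("output.png")

In the Space itself these arguments are supplied by the Gradio inputs wired up elsewhere in app.py.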