Spaces:
Paused
Paused
fix generate_frames target poses
Browse files
main.py
CHANGED
|
@@ -288,7 +288,7 @@ def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
|
|
| 288 |
return in_img, in_pose, train_imgs, train_poses
|
| 289 |
|
| 290 |
|
| 291 |
-
def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app=False):
|
| 292 |
progress=gr.Progress(track_tqdm=True)
|
| 293 |
|
| 294 |
print("prepare_inputs_inference")
|
|
@@ -310,7 +310,8 @@ def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session,
|
|
| 310 |
print("vid: ", in_vid, fps)
|
| 311 |
|
| 312 |
progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
|
| 313 |
-
|
|
|
|
| 314 |
target_poses_coords = []
|
| 315 |
max_left = max_top = 999999
|
| 316 |
max_right = max_bottom = 0
|
|
@@ -333,6 +334,7 @@ def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session,
|
|
| 333 |
target_poses_coords.append(json.dumps(coords))
|
| 334 |
progress_bar.update(1)
|
| 335 |
|
|
|
|
| 336 |
bbox = tpose.getbbox()
|
| 337 |
left, top, right, bottom = bbox
|
| 338 |
max_left = min(max_left, left)
|
|
@@ -498,7 +500,6 @@ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pc
|
|
| 498 |
logging_dir = 'outputs/logging'
|
| 499 |
print('start train')
|
| 500 |
|
| 501 |
-
|
| 502 |
progress=gr.Progress(track_tqdm=True)
|
| 503 |
|
| 504 |
accelerator = Accelerator(
|
|
@@ -1159,7 +1160,8 @@ def run_inference_impl(images, video_path, frames, train_steps=100, inference_st
|
|
| 1159 |
frames = [img[0] for img in frames]
|
| 1160 |
|
| 1161 |
in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
|
| 1162 |
-
|
|
|
|
| 1163 |
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1164 |
#urls = save_temp_imgs(results)
|
| 1165 |
|
|
@@ -1189,19 +1191,20 @@ def generate_frame(images, target_poses, train_steps=100, inference_steps=10, mo
|
|
| 1189 |
is_app=True
|
| 1190 |
|
| 1191 |
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
|
| 1192 |
-
|
| 1193 |
if not os.path.exists(modelId+".pt"):
|
| 1194 |
run_train(images, train_steps, modelId, bg_remove, resize_inputs)
|
| 1195 |
-
|
| 1196 |
images = [img[0] for img in images]
|
| 1197 |
in_img = images[0]
|
| 1198 |
in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
|
| 1199 |
|
| 1200 |
print(target_poses)
|
| 1201 |
target_poses = json.loads(target_poses)
|
| 1202 |
-
|
| 1203 |
-
|
| 1204 |
-
target_poses[
|
|
|
|
| 1205 |
|
| 1206 |
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1207 |
#urls = save_temp_imgs(results)
|
|
|
|
| 288 |
return in_img, in_pose, train_imgs, train_poses
|
| 289 |
|
| 290 |
|
| 291 |
+
def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app=False, target_poses=None):
|
| 292 |
progress=gr.Progress(track_tqdm=True)
|
| 293 |
|
| 294 |
print("prepare_inputs_inference")
|
|
|
|
| 310 |
print("vid: ", in_vid, fps)
|
| 311 |
|
| 312 |
progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
|
| 313 |
+
if not target_poses:
|
| 314 |
+
target_poses = []
|
| 315 |
target_poses_coords = []
|
| 316 |
max_left = max_top = 999999
|
| 317 |
max_right = max_bottom = 0
|
|
|
|
| 334 |
target_poses_coords.append(json.dumps(coords))
|
| 335 |
progress_bar.update(1)
|
| 336 |
|
| 337 |
+
for tpose in target_poses:
|
| 338 |
bbox = tpose.getbbox()
|
| 339 |
left, top, right, bottom = bbox
|
| 340 |
max_left = min(max_left, left)
|
|
|
|
| 500 |
logging_dir = 'outputs/logging'
|
| 501 |
print('start train')
|
| 502 |
|
|
|
|
| 503 |
progress=gr.Progress(track_tqdm=True)
|
| 504 |
|
| 505 |
accelerator = Accelerator(
|
|
|
|
| 1160 |
frames = [img[0] for img in frames]
|
| 1161 |
|
| 1162 |
in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
|
| 1163 |
+
target_poses[0].save('inf_pose.png')
|
| 1164 |
+
|
| 1165 |
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1166 |
#urls = save_temp_imgs(results)
|
| 1167 |
|
|
|
|
| 1191 |
is_app=True
|
| 1192 |
|
| 1193 |
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
|
| 1194 |
+
|
| 1195 |
if not os.path.exists(modelId+".pt"):
|
| 1196 |
run_train(images, train_steps, modelId, bg_remove, resize_inputs)
|
| 1197 |
+
|
| 1198 |
images = [img[0] for img in images]
|
| 1199 |
in_img = images[0]
|
| 1200 |
in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
|
| 1201 |
|
| 1202 |
print(target_poses)
|
| 1203 |
target_poses = json.loads(target_poses)
|
| 1204 |
+
target_poses = [Image.fromarray(draw_openpose(pose, height=img_height, width=img_width, include_hands=True, include_face=False)) for pose in target_poses]
|
| 1205 |
+
|
| 1206 |
+
in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, None, [], 12, dwpose, rembg_session, bg_remove, resize_inputs, is_app, target_poses)
|
| 1207 |
+
target_poses[0].save('gen_pose.png')
|
| 1208 |
|
| 1209 |
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
|
| 1210 |
#urls = save_temp_imgs(results)
|