acmyu committed on
Commit
26ea696
·
1 Parent(s): 9aab429

fix generate_frames target poses

Browse files
Files changed (1) hide show
  1. main.py +12 -9
main.py CHANGED
@@ -288,7 +288,7 @@ def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
288
  return in_img, in_pose, train_imgs, train_poses
289
 
290
 
291
- def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app=False):
292
  progress=gr.Progress(track_tqdm=True)
293
 
294
  print("prepare_inputs_inference")
@@ -310,7 +310,8 @@ def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session,
310
  print("vid: ", in_vid, fps)
311
 
312
  progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
313
- target_poses = []
 
314
  target_poses_coords = []
315
  max_left = max_top = 999999
316
  max_right = max_bottom = 0
@@ -333,6 +334,7 @@ def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session,
333
  target_poses_coords.append(json.dumps(coords))
334
  progress_bar.update(1)
335
 
 
336
  bbox = tpose.getbbox()
337
  left, top, right, bottom = bbox
338
  max_left = min(max_left, left)
@@ -498,7 +500,6 @@ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pc
498
  logging_dir = 'outputs/logging'
499
  print('start train')
500
 
501
-
502
  progress=gr.Progress(track_tqdm=True)
503
 
504
  accelerator = Accelerator(
@@ -1159,7 +1160,8 @@ def run_inference_impl(images, video_path, frames, train_steps=100, inference_st
1159
  frames = [img[0] for img in frames]
1160
 
1161
  in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
1162
-
 
1163
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1164
  #urls = save_temp_imgs(results)
1165
 
@@ -1189,19 +1191,20 @@ def generate_frame(images, target_poses, train_steps=100, inference_steps=10, mo
1189
  is_app=True
1190
 
1191
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1192
-
1193
  if not os.path.exists(modelId+".pt"):
1194
  run_train(images, train_steps, modelId, bg_remove, resize_inputs)
1195
-
1196
  images = [img[0] for img in images]
1197
  in_img = images[0]
1198
  in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
1199
 
1200
  print(target_poses)
1201
  target_poses = json.loads(target_poses)
1202
- width, height = in_img.size
1203
- target_poses = [Image.fromarray(draw_openpose(pose, height=height, width=width, include_hands=True, include_face=False)) for pose in target_poses]
1204
- target_poses[0].save('pose.png')
 
1205
 
1206
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1207
  #urls = save_temp_imgs(results)
 
288
  return in_img, in_pose, train_imgs, train_poses
289
 
290
 
291
+ def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app=False, target_poses=None):
292
  progress=gr.Progress(track_tqdm=True)
293
 
294
  print("prepare_inputs_inference")
 
310
  print("vid: ", in_vid, fps)
311
 
312
  progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
313
+ if not target_poses:
314
+ target_poses = []
315
  target_poses_coords = []
316
  max_left = max_top = 999999
317
  max_right = max_bottom = 0
 
334
  target_poses_coords.append(json.dumps(coords))
335
  progress_bar.update(1)
336
 
337
+ for tpose in target_poses:
338
  bbox = tpose.getbbox()
339
  left, top, right, bottom = bbox
340
  max_left = min(max_left, left)
 
500
  logging_dir = 'outputs/logging'
501
  print('start train')
502
 
 
503
  progress=gr.Progress(track_tqdm=True)
504
 
505
  accelerator = Accelerator(
 
1160
  frames = [img[0] for img in frames]
1161
 
1162
  in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
1163
+ target_poses[0].save('inf_pose.png')
1164
+
1165
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1166
  #urls = save_temp_imgs(results)
1167
 
 
1191
  is_app=True
1192
 
1193
  dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1194
+
1195
  if not os.path.exists(modelId+".pt"):
1196
  run_train(images, train_steps, modelId, bg_remove, resize_inputs)
1197
+
1198
  images = [img[0] for img in images]
1199
  in_img = images[0]
1200
  in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
1201
 
1202
  print(target_poses)
1203
  target_poses = json.loads(target_poses)
1204
+ target_poses = [Image.fromarray(draw_openpose(pose, height=img_height, width=img_width, include_hands=True, include_face=False)) for pose in target_poses]
1205
+
1206
+ in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, None, [], 12, dwpose, rembg_session, bg_remove, resize_inputs, is_app, target_poses)
1207
+ target_poses[0].save('gen_pose.png')
1208
 
1209
  results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1210
  #urls = save_temp_imgs(results)