jhj0517 commited on
Commit
77caed5
1 Parent(s): 699c0d5

refactor test stage

Browse files
Files changed (1) hide show
  1. test_stage.py +8 -6
test_stage.py CHANGED
@@ -1,8 +1,6 @@
1
  import argparse
2
  from omegaconf import OmegaConf
3
- import torch
4
- from pprint import pprint
5
-
6
  from musepose_inference import MusePoseInference
7
 
8
 
@@ -20,7 +18,8 @@ def parse_args():
20
  parser.add_argument("--steps", type=int, default=20, help="DDIM sampling steps")
21
  parser.add_argument("--fps", type=int)
22
  parser.add_argument("--weight_dtype", type=str, default="fp16")
23
- parser.add_argument("--output_dir", type=str, default="./output")
 
24
 
25
  parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)")
26
  args = parser.parse_args()
@@ -32,12 +31,15 @@ def main():
32
  args = parse_args()
33
  config = OmegaConf.load(args.config)
34
 
35
- musepose_infer = MusePoseInference(config=config, output_dir=args.output_dir)
 
 
 
36
 
37
  ref_image_path = list(config["test_cases"].keys())[0]
38
  pose_video_path = config["test_cases"][ref_image_path][0]
39
 
40
- output_file_path = musepose_infer.infer_musepose(
41
  ref_image_path=ref_image_path,
42
  pose_video_path=pose_video_path,
43
  weight_dtype=args.weight_dtype,
 
1
  import argparse
2
  from omegaconf import OmegaConf
3
+ import os
 
 
4
  from musepose_inference import MusePoseInference
5
 
6
 
 
18
  parser.add_argument("--steps", type=int, default=20, help="DDIM sampling steps")
19
  parser.add_argument("--fps", type=int)
20
  parser.add_argument("--weight_dtype", type=str, default="fp16")
21
+ parser.add_argument('--model_dir', type=str, default=os.path.join("pretrained_weights"), help='Pretrained models directory for MusePose')
22
+ parser.add_argument('--output_dir', type=str, default=os.path.join("assets", "videos"), help='Output directory for the result')
23
 
24
  parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)")
25
  args = parser.parse_args()
 
31
  args = parse_args()
32
  config = OmegaConf.load(args.config)
33
 
34
+ musepose_infer = MusePoseInference(
35
+ model_dir=args.model_dir,
36
+ output_dir=args.output_dir
37
+ )
38
 
39
  ref_image_path = list(config["test_cases"].keys())[0]
40
  pose_video_path = config["test_cases"][ref_image_path][0]
41
 
42
+ output_file_path, output_demo_file_path = musepose_infer.infer_musepose(
43
  ref_image_path=ref_image_path,
44
  pose_video_path=pose_video_path,
45
  weight_dtype=args.weight_dtype,