File size: 2,112 Bytes
f43c9b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be3490d
f43c9b7
 
 
 
483d460
f43c9b7
 
 
 
2e55dd7
 
f43c9b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import textwrap
import subprocess
import shutil
import os
from pathlib import Path

import torch
import gradio as gr
from huggingface_hub import hf_hub_download


REPO_ID = "kbrodt/sketch2pose"
# Hub access token, read from the "sketch2pose" environment variable; a
# KeyError here means the variable is not set.
API_TOKEN = os.environ["sketch2pose"]
ASSET_DIR = Path("./assets")

# Download the SMPL-X model archive from the Hub (ASSET_DIR doubles as the Hub
# cache directory) and keep a plain copy directly at ASSET_DIR/<filename>.
filename = "models_smplx_v1_1.zip"
smpl_path = hf_hub_download(
    repo_id=REPO_ID,
    repo_type="model",
    filename=filename,
    use_auth_token=API_TOKEN,
    cache_dir=ASSET_DIR,
)
if not (ASSET_DIR / filename).is_file():
    shutil.copy(smpl_path, ASSET_DIR)


def _run_setup_script(script):
    """Run ``bash <script>`` and report a non-zero exit status.

    Setup stays best-effort (start-up continues after a failure, exactly as
    before), but a failing script is no longer silent.

    Args:
        script: Path of the bash script to run.

    Returns:
        The ``subprocess.CompletedProcess`` describing the run.
    """
    result = subprocess.run(["bash", script])
    if result.returncode != 0:
        print(
            "WARNING: setup script {} exited with status {}".format(script, result.returncode),
            flush=True,
        )
    return result


_run_setup_script("./scripts/download.sh")
_run_setup_script("./scripts/prepare.sh")


# Output root (relative to the working directory) handed to pose.py.
SAVE_DIR = "output"

# Command-line template for the pose script. The two ``{}`` placeholders are
# filled, in order, with the save directory and the input image path; the
# text is whitespace-separated so callers can tokenize it into argv.
CMD = textwrap.dedent(
    """
    python src/pose.py
        --save-path {}
        --img-path {}
    """
)


def main():
    """Build and launch the Gradio demo that turns a sketch image into a 3D pose.

    The interface takes an image path plus four option checkboxes, runs
    ``src/pose.py`` in a subprocess (command template ``CMD``), and shows the
    resulting ``us.glb`` mesh in a ``gr.Model3D`` viewer. Meshes are cached on
    disk under ``SAVE_DIR``; a request whose mesh already exists skips the
    subprocess.
    """
    import shlex  # only needed here, to tokenize the command template safely

    save_dir = Path(SAVE_DIR)
    save_dir.mkdir(parents=True, exist_ok=True)

    def pose(img_path, use_cos=True, use_angle_transf=True, use_contacts=False, use_natural=True):
        """Run pose.py on one image and return the path of the produced mesh.

        Args:
            img_path: Path of the input image file (the image input uses
                ``type="filepath"``), or ``None`` if nothing was provided.
            use_cos: Pass ``--use-cos`` to pose.py.
            use_angle_transf: Pass ``--use-angle-transf``; forced off whenever
                ``use_cos`` is off.
            use_contacts: Pass ``--use-contacts``.
            use_natural: Pass ``--use-natural``.

        Returns:
            The mesh path (``.../us.glb``) as a string.

        Raises:
            gr.Error: if no image was given, or pose.py produced no mesh.
        """
        if img_path is None:
            raise gr.Error("Please provide an input image.")
        if not use_cos:
            # --use-angle-transf is never passed without --use-cos.
            use_angle_transf = False

        # Insertion order fixes the order the flags appear on the command line.
        flags = {
            "--use-cos": use_cos,
            "--use-angle-transf": use_angle_transf,
            "--use-contacts": use_contacts,
            "--use-natural": use_natural,
        }

        # The on-disk cache must be keyed on the options as well as the image
        # name: each flag combination gets its own output root, so toggling a
        # checkbox never returns a mesh computed with different settings.
        run_dir = save_dir / "cos{:d}_angle{:d}_contacts{:d}_natural{:d}".format(
            bool(use_cos), bool(use_angle_transf), bool(use_contacts), bool(use_natural)
        )
        run_dir.mkdir(parents=True, exist_ok=True)

        # Quote the substituted paths so a path containing spaces stays a
        # single argument after tokenization (plain str.split() would break it).
        cmd = shlex.split(CMD.format(shlex.quote(str(run_dir)), shlex.quote(str(img_path))))
        cmd.extend(flag for flag, enabled in flags.items() if enabled)

        # Assumes pose.py writes <save-path>/<image name without suffix>/us.glb,
        # the same layout the lookup has always relied on.
        out_dir = (run_dir / Path(img_path).name).with_suffix("")
        mesh_path = out_dir / "us.glb"

        if not mesh_path.is_file():
            subprocess.run(cmd)

        if not mesh_path.is_file():
            # Report the failure in the UI rather than returning a path to a
            # file that does not exist.
            raise gr.Error("Pose estimation failed: no mesh was produced.")

        return str(mesh_path)

    # Offer the first file under ./data/images (sorted, for a stable choice),
    # if there is one, as a clickable example with the default option values.
    candidates = sorted(p for p in Path("./data/images").glob("*") if p.is_file())
    examples = [[str(candidates[0]), True, True, False, True]] if candidates else []

    demo = gr.Interface(
        fn=pose,
        inputs=[
            gr.Image(type="filepath"),
            gr.Checkbox(value=True),
            gr.Checkbox(value=True),
            # The contacts option can only be toggled when CUDA is available.
            gr.Checkbox(value=False, interactive=torch.cuda.is_available()),
            gr.Checkbox(value=True),
        ],
        outputs=gr.Model3D(),
        examples=examples,
    )

    demo.launch()


if __name__ == "__main__":
    main()