File size: 3,057 Bytes
9bdd97c
 
 
 
 
 
 
 
f57e36f
9bdd97c
 
 
 
 
 
 
161720b
9bdd97c
 
 
161720b
9bdd97c
161720b
9bdd97c
 
f57e36f
161720b
f57e36f
161720b
 
f57e36f
 
161720b
f57e36f
 
 
161720b
f57e36f
 
 
 
 
 
 
 
 
161720b
f57e36f
 
 
9bdd97c
 
 
161720b
 
 
 
9bdd97c
 
 
 
 
 
 
76221dd
9bdd97c
 
 
 
 
 
ed7463d
 
 
161720b
ed7463d
 
 
7479a3a
 
161720b
7479a3a
 
 
161720b
 
 
7479a3a
161720b
7479a3a
 
161720b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python

from __future__ import annotations

import functools
import os
import pathlib
import sys
import tarfile

import gradio as gr
import huggingface_hub
import PIL.Image
import torch
import torchvision

sys.path.insert(0, "bizarre-pose-estimator")

from _util.twodee_v0 import I as ImageWrapper

# Markdown header rendered at the top of the demo page (see gr.Markdown below).
DESCRIPTION = "# [ShuhongChen/bizarre-pose-estimator (tagger)](https://github.com/ShuhongChen/bizarre-pose-estimator)"

# Hugging Face Hub model repo that hosts both the tagger checkpoint
# ("tagger.pth") and the tag list ("tags.txt").
MODEL_REPO = "public-data/bizarre-pose-estimator-models"


def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sorted sample image paths, downloading them on first use.

    If ``./images`` does not exist yet, ``images.tar.gz`` is fetched from the
    ``hysts/sample-images-TADNE`` dataset repo and extracted into the current
    working directory. (Presumably the archive contains a top-level
    ``images/`` directory — TODO confirm; otherwise the result is empty.)

    Returns:
        Every entry matched by ``images/*``, sorted by path.
    """
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(path) as f:
            # The archive comes from a remote repo, so use the "data"
            # extraction filter where the interpreter supports it: it rejects
            # absolute paths, ".." traversal, links escaping the destination
            # and special files. Interpreters without extraction filters keep
            # the previous, unfiltered behaviour.
            if hasattr(tarfile, "data_filter"):
                f.extractall(filter="data")
            else:
                f.extractall()
    return sorted(image_dir.glob("*"))


def load_model(device: torch.device, num_classes: int = 1062) -> torch.nn.Module:
    """Download the tagger checkpoint and return a ResNet-50 ready for inference.

    Args:
        device: Device the model's parameters are moved to.
        num_classes: Width of the final classification layer. It must match the
            checkpoint; the default 1062 is the value this app has always used
            for ``tagger.pth``.

    Returns:
        A torchvision ResNet-50 with the downloaded weights, moved to
        ``device`` and switched to eval mode.
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO, "tagger.pth")
    # Deserialize onto the CPU first. If the checkpoint was saved from CUDA
    # tensors, a bare torch.load would try to restore them onto a GPU and fail
    # on CPU-only hosts. The model is moved to `device` below in any case.
    # NOTE(review): torch.load unpickles the file, and this file is downloaded
    # from a remote Hub repo. Consider weights_only=True if the installed torch
    # version supports it.
    state_dict = torch.load(path, map_location="cpu")
    model = torchvision.models.resnet50(num_classes=num_classes)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def load_labels() -> list[str]:
    """Download ``tags.txt`` from the model repo and return one tag per line.

    Order is significant: ``predict`` pairs the i-th label with the i-th model
    output.

    Returns:
        The whitespace-stripped lines of ``tags.txt``, in file order.
    """
    label_path = huggingface_hub.hf_hub_download(MODEL_REPO, "tags.txt")
    # Decode explicitly as UTF-8. The platform default encoding (e.g. cp1252 on
    # Windows) would mis-decode or reject any non-ASCII tag.
    with open(label_path, encoding="utf-8") as f:
        return [line.strip() for line in f]


@torch.inference_mode()
def predict(
    image: PIL.Image.Image, score_threshold: float, device: torch.device, model: torch.nn.Module, labels: list[str]
) -> dict[str, float]:
    """Tag an image, returning every label whose sigmoid score reaches the threshold.

    Args:
        image: Input image from the Gradio ``Image`` component. It may be
            ``None`` when Run is clicked without an upload.
        score_threshold: Minimum probability (inclusive) for a label to be reported.
        device: Device the model lives on; the input tensor is moved there.
        model: Multi-label tagger producing one logit per label.
        labels: Tag names, aligned index-by-index with the model outputs.

    Returns:
        Mapping of label -> probability for each label scoring at or above
        ``score_threshold``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of failing deep inside
        # the preprocessing helper.
        raise gr.Error("Please upload an image.")

    # Upstream preprocessing helper. Presumably this resizes to a 256x256
    # square, composites any alpha channel over white, and converts to an RGB
    # tensor (TODO confirm against _util.twodee_v0).
    data = ImageWrapper(image).resize_square(256).alpha_bg(c="w").convert("RGB").tensor()
    # Add the batch dimension; the single result is taken back out with [0].
    data = data.to(device).unsqueeze(0)

    logits = model(data)[0]
    # Independent per-label probabilities (multi-label), as plain Python floats.
    probs = torch.sigmoid(logits).cpu().numpy().astype(float).tolist()

    return {label: prob for prob, label in zip(probs, labels) if prob >= score_threshold}


# Rows for the Examples table: each sample image, paired with a default 0.5 threshold.
image_paths = load_sample_image_paths()
examples = [[sample.as_posix(), 0.5] for sample in image_paths]

# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = load_model(device)
labels = load_labels()

# Pre-bind the fixed resources so the UI callbacks only pass (image, threshold).
fn = functools.partial(predict, labels=labels, model=model, device=device)

# UI layout: a header, then one row with two columns. The left column holds
# the inputs and the Run button; the right column holds the predicted label
# scores. Component creation order below determines the on-page layout.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input", type="pil")
            threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5)
            run_button = gr.Button("Run")
        with gr.Column():
            result = gr.Label(label="Output")

    # Order must match predict's leading positional parameters
    # (image, score_threshold); the remaining ones are pre-bound in `fn`.
    inputs = [image, threshold]
    # Sample gallery. Example caching is enabled only when the CACHE_EXAMPLES
    # environment variable is exactly "1".
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=result,
        fn=fn,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    # Run button triggers inference; api_name exposes it as the "predict" endpoint.
    run_button.click(
        fn=fn,
        inputs=inputs,
        outputs=result,
        api_name="predict",
    )

if __name__ == "__main__":
    # Enable request queueing (queue length capped at 15), then start the server.
    demo.queue(max_size=15).launch()