File size: 3,254 Bytes
80fecf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from rtmlib import YOLOX, RTMPose, draw_bbox, draw_skeleton
import functools
from typing import Callable
from pathlib import Path
import gradio as gr
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T


# Title shown at the top of the Gradio interface.
TITLE = 'Human Parsing'

def get_palette(num_cls):
    """Return a flat RGB color map for visualizing a segmentation mask.

    Uses the PASCAL-VOC bit-spreading scheme: the bits of each class id
    are distributed across the high bits of the R, G, B channels.

    Args:
        num_cls: Number of classes.

    Returns:
        A flat list of ``num_cls * 3`` ints (R, G, B per class), suitable
        for ``PIL.Image.putpalette``.
    """
    palette = []
    for cls_id in range(num_cls):
        r = g = b = 0
        label = cls_id
        bit = 0
        while label:
            # Bit k of the label goes into bit (7 - k) of each channel,
            # consuming three label bits (one per channel) per round.
            r |= ((label >> 0) & 1) << (7 - bit)
            g |= ((label >> 1) & 1) << (7 - bit)
            b |= ((label >> 2) & 1) << (7 - bit)
            bit += 1
            label >>= 3
        palette.extend((r, g, b))
    return palette

@torch.inference_mode()
def predict(image: PIL.Image.Image, model, transform: Callable,
            device: torch.device, palette) -> tuple:
    """Run detection, pose estimation, and human parsing on one image.

    Args:
        image: Input RGB PIL image.
        model: Sequence of three models, indexed as
            ``[parsing_model, det_model, pose_model]`` (see ``main``).
        transform: Preprocessing transform applied before the parsing model.
        device: Device the parsing model runs on.
        palette: Flat RGB palette (as from ``get_palette``) for the label map.

    Returns:
        A ``(labels, blended)`` pair of PIL images: the colorized parsing
        mask, and a 50/50 blend of that mask with the detection/pose overlay.
        (The previous ``np.ndarray`` annotation was incorrect.)
    """
    # Convert once; reuse for both the model inputs and the overlay copy.
    img_np = np.array(image)
    img_show = img_np.copy()

    # Person detection + 2D pose, drawn onto the visualization copy.
    bboxes = model[1](img_np)
    img_show = draw_bbox(img_show, bboxes)
    keypoints, scores = model[2](img_np, bboxes=bboxes)
    img_show = draw_skeleton(img_show, keypoints, scores)

    # Human-parsing segmentation, upsampled back to the input resolution.
    # PIL size is (width, height); interpolate wants (height, width).
    data = transform(image).unsqueeze(0).to(device)
    out = model[0](data)
    out = F.interpolate(out, [image.size[1], image.size[0]], mode="bilinear")
    parsing = torch.argmax(out[0].permute(1, 2, 0), dim=2).cpu().numpy()

    output_im = Image.fromarray(parsing.astype(np.uint8))
    output_im.putpalette(palette)
    output_im = output_im.convert('RGB')

    overlay = Image.fromarray(img_show.astype(np.uint8)).convert('RGB')
    res = Image.blend(overlay, output_im, 0.5)
    return output_im, res


def load_parsing_model():
    """Load the TorchScript human-parsing model and switch it to eval mode."""
    # Module.eval() returns the module itself, so the two steps chain.
    parsing_model = torch.jit.load(Path("models/humanparsing_572_384.pt"))
    return parsing_model.eval()


def main():
    """Build the model pipeline and launch the Gradio demo on CPU."""
    device = torch.device('cpu')

    # Preprocessing for the parsing model: resize to its fixed input size,
    # then standard ImageNet normalization.
    preprocess = T.Compose([
        T.Resize((572, 384), interpolation=PIL.Image.NEAREST),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    palette = get_palette(20)

    # Order matters: predict() indexes this as [parsing, detector, pose].
    models = [
        load_parsing_model(),
        YOLOX('models/det.onnx', model_input_size=(640, 640),
              backend='onnxruntime', device='cpu'),
        RTMPose('models/pose.onnx', model_input_size=(192, 256),
                to_openpose=False, backend='onnxruntime', device='cpu'),
    ]

    func = functools.partial(predict,
                             model=models,
                             transform=preprocess,
                             device=device, palette=palette)

    gr.Interface(
        fn=func,
        inputs=gr.Image(label='Input', type='pil'),
        outputs=[
            gr.Image(label='Predicted Labels', type='pil'),
            gr.Image(label='Masked', type='pil'),
        ],
        title=TITLE,
    ).queue().launch(show_api=False)

if __name__ == "__main__":
    main()