File size: 2,742 Bytes
61c2d32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.

from typing import List

import cv2
import numpy as np
import torch
from densepose import add_densepose_config
from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
from densepose.vis.extractor import DensePoseOutputsExtractor, DensePoseResultExtractor
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from PIL import Image


class DensePose4Gradio:
    """Run a DensePose model on a PIL image and render its part-label map.

    The rendered output is an RGB image the same size as the input. It is black
    everywhere except inside the box of the first instance in the predictor's
    output. There, each DensePose part label is colored with OpenCV's PARULA
    colormap, and label 0 is left black.
    """

    # Largest part label expected in the label map. Labels are scaled by
    # 255 / this value before colormapping so they span the colormap's range.
    _MAX_PART_LABEL = 24

    def __init__(self, cfg, model) -> None:
        """Build the detectron2 predictor.

        Args:
            cfg: DensePose config file passed to ``setup_config``.
            model: model weights location, written to ``cfg.MODEL.WEIGHTS``.
        """
        cfg = self.setup_config(cfg, model, [])
        self.predictor = DefaultPredictor(cfg)

    def setup_config(
        self, config_fpath: str, model_fpath: str, opts: List[str]
    ):
        """Return a frozen detectron2 config for a DensePose model.

        Args:
            config_fpath: config file merged on top of the DensePose defaults.
            model_fpath: weights location assigned to ``cfg.MODEL.WEIGHTS``.
            opts: extra overrides forwarded to ``cfg.merge_from_list``;
                skipped when empty.
        """
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(config_fpath)
        if opts:
            cfg.merge_from_list(opts)
        cfg.MODEL.WEIGHTS = model_fpath
        cfg.freeze()
        return cfg

    def execute(self, image: Image.Image):
        """Run the model on ``image`` and return the rendered part-label map.

        Returns:
            Whatever ``execute_on_outputs`` returns for the predictions.
        """
        # Normalize to 3-channel RGB first so RGBA / grayscale uploads do not
        # make cvtColor fail, then feed the predictor OpenCV-style BGR.
        rgb = np.array(image.convert("RGB"))
        img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        with torch.no_grad():
            outputs = self.predictor(img)["instances"]
        return self.execute_on_outputs(img, outputs)

    @classmethod
    def _colorize_labels(cls, labels: np.ndarray) -> np.ndarray:
        """Map a 2-D part-label array to an RGB uint8 image of the same size.

        Label 0 is painted black; every other label gets a PARULA color.
        """
        scaled = (labels.astype(np.float32) * 255 / cls._MAX_PART_LABEL).astype(np.uint8)
        # applyColorMap produces BGR; swap to RGB for PIL.
        colored = cv2.cvtColor(
            cv2.applyColorMap(scaled, cv2.COLORMAP_PARULA), cv2.COLOR_BGR2RGB
        )
        colored[labels == 0] = 0
        return colored

    def execute_on_outputs(self, image: np.ndarray, outputs: Instances):
        """Render the first instance's DensePose part labels onto a black canvas.

        Only the first instance in ``outputs`` is drawn. NOTE(review): this is
        presumably the highest-scoring detection; confirm the predictor orders
        instances by score.

        Args:
            image: the image the model ran on; only its height and width are used.
            outputs: detectron2 ``Instances`` produced by the predictor.

        Returns:
            A PIL RGB image with the same height/width as ``image``. The image
            is entirely black when no instance was detected. Returns ``None``
            when ``outputs`` lacks the ``pred_boxes`` or ``pred_densepose``
            field.

        Raises:
            NotImplementedError: if ``pred_densepose`` is not a
                ``DensePoseChartPredictorOutput`` (e.g. an embedding-based
                head), since only chart results carry part labels to render.
        """
        if not (outputs.has("pred_boxes") and outputs.has("pred_densepose")):
            return None

        dp_output = outputs.pred_densepose
        if not isinstance(dp_output, DensePoseChartPredictorOutput):
            raise NotImplementedError(
                "Rendering is only supported for chart-based DensePose outputs, "
                f"got {type(dp_output).__name__}"
            )

        height, width = image.shape[:2]
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        if len(outputs) == 0:
            return Image.fromarray(canvas, "RGB")

        chart_results = DensePoseResultExtractor()(outputs)[0]
        labels = chart_results[0].labels.cpu().numpy()
        colored = self._colorize_labels(labels)

        # Read the box's top-left corner without mutating the predictor's tensor.
        x, y = (int(v) for v in outputs.pred_boxes.tensor[0, :2].tolist())

        # Size the paste region from the label map itself, not from the float
        # box. Clip it to the canvas so rounding or a box touching the border
        # cannot cause a shape mismatch.
        box_h, box_w = colored.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + box_w, width), min(y + box_h, height)
        if x1 > x0 and y1 > y0:
            canvas[y0:y1, x0:x1] = colored[y0 - y : y1 - y, x0 - x : x1 - x]

        return Image.fromarray(canvas, "RGB")