rlawjdghek's picture
prep (#1)
61c2d32 verified
raw
history blame
No virus
2.74 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import List
import cv2
import numpy as np
import torch
from densepose import add_densepose_config
from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
from densepose.vis.extractor import DensePoseOutputsExtractor, DensePoseResultExtractor
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from PIL import Image
class DensePose4Gradio:
    """Wrap a detectron2 DensePose predictor for use from a Gradio demo.

    Runs DensePose on a single image and renders the body-part label map of
    the first detected instance as an RGB image the same size as the input.
    """

    def __init__(self, cfg, model) -> None:
        """Build the predictor.

        Args:
            cfg: Path to a DensePose YAML config file.
            model: Path to the model weights (assigned to ``cfg.MODEL.WEIGHTS``).
        """
        cfg = self.setup_config(cfg, model, [])
        self.predictor = DefaultPredictor(cfg)

    def setup_config(
        self, config_fpath: str, model_fpath: str, opts: List[str]
    ):
        """Return a frozen detectron2 config extended with DensePose keys.

        Args:
            config_fpath: YAML config file merged over the detectron2 defaults.
            model_fpath: Weights path stored in ``cfg.MODEL.WEIGHTS``.
            opts: Flat ``[key, value, ...]`` overrides applied after the file;
                ignored when empty.

        Returns:
            The frozen config object.
        """
        cfg = get_cfg()
        add_densepose_config(cfg)  # register DensePose-specific config keys
        cfg.merge_from_file(config_fpath)
        if opts:
            cfg.merge_from_list(opts)
        cfg.MODEL.WEIGHTS = model_fpath
        cfg.freeze()
        return cfg

    def execute(self, image: Image.Image):
        """Run DensePose on ``image`` and return the rendered part-label map.

        Args:
            image: Input picture. PIL images in non-RGB modes (e.g. ``L``,
                ``P``, ``RGBA``) are converted to RGB first so the colour
                conversion below always receives three channels.

        Returns:
            A PIL ``RGB`` image of the same size as ``image``; see
            ``execute_on_outputs``.
        """
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        # The predictor is fed a BGR array (detectron2's default INPUT.FORMAT).
        img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        with torch.no_grad():
            outputs = self.predictor(img)["instances"]
        return self.execute_on_outputs(img, outputs)

    def execute_on_outputs(self, image: np.ndarray, outputs: Instances):
        """Render the DensePose part labels of the first detection.

        Args:
            image: The array that was fed to the predictor; only its height
                and width are used, to size the output canvas.
            outputs: ``Instances`` produced by the predictor for ``image``.

        Returns:
            A PIL ``RGB`` image of size ``(W, H)``. Pixels of the first
            instance that carry a non-zero part label are coloured with
            OpenCV's PARULA colormap and pasted at that instance's box;
            everything else is black. If nothing was detected, the whole
            image is black.

        Raises:
            TypeError: if ``outputs.pred_densepose`` is neither a chart nor an
                embedding predictor output.
        """
        boxes = (
            outputs.get("pred_boxes").tensor.cpu()
            if outputs.has("pred_boxes")
            else None
        )

        densepose = None
        if outputs.has("pred_densepose"):
            dp_output = outputs.pred_densepose
            if isinstance(dp_output, DensePoseChartPredictorOutput):
                extractor = DensePoseResultExtractor()
            elif isinstance(dp_output, DensePoseEmbeddingPredictorOutput):
                # NOTE(review): rendering below reads ``.labels``, which chart
                # results provide; the embedding extractor's first element
                # looks like the raw predictor output without ``.labels`` —
                # confirm before relying on this path.
                extractor = DensePoseOutputsExtractor()
            else:
                # Fail loudly instead of reaching an unbound ``extractor``.
                raise TypeError(
                    "Unsupported DensePose output type: "
                    f"{type(dp_output).__name__}"
                )
            densepose = extractor(outputs)[0]

        H, W = image.shape[:2]
        canvas = np.zeros((H, W, 3), dtype=np.uint8)

        # Nothing detected (no DensePose output or no boxes): nothing to draw.
        if densepose is None or boxes is None or len(boxes) == 0:
            return Image.fromarray(canvas, "RGB")

        # Per-pixel part index of the first instance; 0 is treated as background.
        labels = densepose[0].labels.cpu().numpy()
        # Stretch indices 0..24 over 0..255 so the colormap's range is used.
        scaled = (labels.astype(np.float32) * 255 / 24).astype(np.uint8)
        colored = cv2.applyColorMap(scaled, cv2.COLORMAP_PARULA)
        # applyColorMap yields BGR; swap channels to RGB for PIL.
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
        colored[labels == 0] = 0

        # Paste the patch at the box's (truncated) top-left corner. The paste
        # size comes from the patch itself and is clipped to the canvas, so a
        # rounding mismatch with the box cannot cause a shape error. The box
        # tensor is only read, never mutated.
        x, y = (int(v) for v in boxes[0][:2].tolist())
        patch_h, patch_w = colored.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x_end, y_end = min(x + patch_w, W), min(y + patch_h, H)
        if x_end > x0 and y_end > y0:
            canvas[y0:y_end, x0:x_end] = colored[
                y0 - y:y_end - y, x0 - x:x_end - x
            ]
        return Image.fromarray(canvas, "RGB")