File size: 2,530 Bytes
a9bd37f
 
63f2d6e
 
a9bd37f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import cv2
import gradio as gr
import mim
mim.install('mmcv-full==1.5.0')
from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
                         vis_pose_result, process_mmdet_results)
from mmdet.apis import inference_detector, init_detector
import mediapy

# Config / checkpoint file pairs for the two networks used by the demo.
det_config = 'configs/faster_rcnn_r50_fpn_1x_coco.py'
det_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
pose_config = 'configs/topdown_heatmap_hrnet_w48_coco_256x192.py'
pose_checkpoint = 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'

# Both models are built once, at import time, on the CPU device.
det_model = init_detector(
    det_config,
    det_checkpoint,
    device='cpu',
)
pose_model = init_pose_model(
    pose_config,
    pose_checkpoint,
    device='cpu',
)


# Upper bound on the number of frames read from the start of each video.
max_num_frames = 120


def predict(video_path):
    """Detect people in a video, estimate their poses and render the result.

    At most ``max_num_frames`` frames are read from the beginning of the
    video. Every frame is passed through the detector, the detections are
    converted to person boxes and handed to the top-down pose model, and the
    visualised poses are encoded into a new MP4 file.

    Args:
        video_path: Path of the input video file.

    Returns:
        Path of the rendered MP4 file (a fresh temporary file per call).

    Raises:
        ValueError: If no video was given, it cannot be opened, or no frame
            could be decoded from it.
    """
    import tempfile

    if not video_path:
        raise ValueError("No input video was provided.")

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path!r}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        dataset = pose_model.cfg.data.test.type

        for _ in range(max_num_frames):
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV decodes frames as BGR, which is the channel order the
            # mmdet / mmpose inference APIs assume for ndarray input, so the
            # frame is passed through unchanged.
            mmdet_results = inference_detector(det_model, frame)
            person_results = process_mmdet_results(mmdet_results, cat_id=1)
            pose_results, _ = inference_top_down_pose_model(
                pose_model,
                frame,
                person_results,
                bbox_thr=0.3,
                format='xyxy',
                dataset=dataset)
            vis_result = vis_pose_result(
                pose_model,
                frame,
                pose_results,
                dataset=dataset,
                show=False)
            # mediapy encodes RGB frames, so convert only the rendered image.
            frames.append(cv2.cvtColor(vis_result, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle even when inference fails midway.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from: {video_path!r}")

    # Some containers report 0 or NaN as frame rate; fall back to a sane
    # default instead of handing an invalid rate to the encoder.
    if not fps > 0:
        fps = 30.0

    # A unique file per call, so overlapping requests cannot overwrite each
    # other's result.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
        out_path = out_file.name
    mediapy.write_video(out_path, frames, fps=fps)
    return out_path

# Texts displayed by the web UI; description and article are empty strings.
title = "Pose Estimation video"
description = ""
article = ""

# Example inputs passed to the interface.
example_list = ['examples/000001_mpiinew_test.mp4']

# Assemble the Gradio interface: one video in, one video out, via predict().
video_input = gr.Video(label='Input Video')
video_output = gr.Video(label='Result')
demo = gr.Interface(
    fn=predict,
    inputs=video_input,
    outputs=video_output,
    title=title,
    description=description,
    article=article,
    examples=example_list,
)

# Turn on request queueing, then start the app with show_api disabled.
demo.queue()
demo.launch(show_api=False)