"""Gradio demo: top-down human pose estimation on a short video clip.

For each frame, a Faster R-CNN detector (configs/faster_rcnn_*) proposes
person boxes, an HRNet-W48 top-down model (configs/topdown_heatmap_hrnet_w48_*)
estimates keypoints inside those boxes, and the annotated frames are written
to an MP4 file that Gradio displays.
"""
import tempfile

import cv2
import gradio as gr
import mim

# Install mmcv-full at startup; this must run before the mmpose/mmdet
# imports below.
mim.install('mmcv-full==1.5.0')

from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
                         vis_pose_result, process_mmdet_results)
from mmdet.apis import inference_detector, init_detector
import mediapy

pose_config = 'configs/topdown_heatmap_hrnet_w48_coco_256x192.py'
pose_checkpoint = 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
det_config = 'configs/faster_rcnn_r50_fpn_1x_coco.py'
det_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

# initialize pose model
pose_model = init_pose_model(pose_config, pose_checkpoint, device='cpu')
# initialize detector
det_model = init_detector(det_config, det_checkpoint, device='cpu')

# Upper bound on the number of frames processed per request.
max_num_frames = 120

# Frame rate used when the input container does not report one
# (cv2 returns 0 for CAP_PROP_FPS in that case).
_DEFAULT_FPS = 30.0


def predict(video_path):
    """Detect people and draw their estimated poses on the first frames of a video.

    Args:
        video_path: Path to the uploaded video file, as passed by ``gr.Video``.

    Returns:
        Path to a newly written MP4 file. It holds at most ``max_num_frames``
        annotated frames.

    Raises:
        gr.Error: If no video was supplied or no frame could be decoded.
    """
    if not video_path:
        raise gr.Error('Please upload a video.')

    # Loop-invariant: the dataset type used for inference and visualisation.
    dataset = pose_model.cfg.data.test.type

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or _DEFAULT_FPS
        for _ in range(max_num_frames):
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV decodes frames as BGR. mmdet/mmpose expect BGR ndarray
            # input and do their own BGR->RGB conversion, so pass it unchanged.
            mmdet_results = inference_detector(det_model, frame)
            # Keep only person detections (category id 1 in COCO numbering).
            person_results = process_mmdet_results(mmdet_results, cat_id=1)
            pose_results, _ = inference_top_down_pose_model(
                pose_model,
                frame,
                person_results,
                bbox_thr=0.3,
                format='xyxy',
                dataset=dataset)
            vis_result = vis_pose_result(
                pose_model, frame, pose_results, dataset=dataset, show=False)
            # The visualisation keeps the input's BGR order; mediapy writes RGB.
            frames.append(cv2.cvtColor(vis_result, cv2.COLOR_BGR2RGB))
    finally:
        # Release the decoder even if inference raises part-way through.
        cap.release()

    if not frames:
        raise gr.Error('Could not read any frames from the uploaded video.')

    # Use a unique output file per request so results never overwrite each other.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
        out_path = out_file.name
    mediapy.write_video(out_path, frames, fps=fps)
    return out_path


title = "Pose Estimation video"
description = ""
article = ""

example_list = ['examples/000001_mpiinew_test.mp4']

# Create the Gradio demo
demo = gr.Interface(fn=predict,
                    inputs=gr.Video(label='Input Video'),
                    outputs=gr.Video(label='Result'),
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo!
demo.queue().launch(show_api=False)