|
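"""Gradio demo for 2D/3D human pose estimation and 2D hand pose estimation
(MMPose) plus object detection and tracking (Ultralytics YOLOv8) on videos."""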
|
|
|
|
|
|
from mmpose.apis import MMPoseInferencer
|
|
|
|
|
from ultralytics import YOLO |
|
import torch |
|
|
|
|
|
import gradio as gr |
|
import moviepy.editor as mpy
|
|
|
|
|
|
|
import os |
|
import glob |
|
import uuid |
|
|
|
|
|
import numpy as np |
|
import cv2 |
|
|
|
print("[INFO]: Torch version:", torch.__version__)
|
|
|
# Run inference on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
print("[INFO]: Imported modules!") |
|
human = MMPoseInferencer("human") |
|
hand = MMPoseInferencer("hand") |
|
human3d = MMPoseInferencer(pose3d="human3d") |
|
track_model = YOLO('yolov8n.pt') |
|
|
|
|
|
print("[INFO]: Downloaded models!") |
|
|
|
def check_extension(video): |
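    """Convert the input video to MP4 if it has a different extension."""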
|
    file_name, file_extension = os.path.splitext(video)
|
|
|
if file_extension != ".mp4": |
|
print("Converting to mp4") |
|
clip = moviepy.VideoFileClip(video) |
|
|
|
video = file_name+".mp4" |
|
clip.write_videofile(video) |
|
|
|
return video |
|
|
|
|
|
def tracking(video, model, boxes=True): |
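    """Run a YOLOv8 inference/tracking callable on a video and return the per-frame results."""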
|
print("[INFO] Is cuda available? ", torch.cuda.is_available()) |
|
print(device) |
|
|
|
print("[INFO] Loading model...") |
|
|
|
|
|
|
|
print("[INFO] Starting tracking!") |
|
|
|
annotated_frame = model(video, boxes=boxes, device=device) |
|
|
|
return annotated_frame |
|
|
|
def show_tracking(video_content): |
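    """Track objects in a video with YOLOv8 and write the annotated frames to an MP4 file."""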
|
|
|
|
|
video = cv2.VideoCapture(video_content) |
|
|
|
|
|
video_track = tracking(video_content, track_model.track) |
|
|
|
|
|
|
|
out_file = "track.mp4" |
|
print("[INFO]: TRACK", out_file) |
|
|
|
    # Write the output with the mp4v codec at the input video's FPS and frame size.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    fps = video.get(cv2.CAP_PROP_FPS)
    height, width, _ = video_track[0].orig_img.shape
    size = (width, height)
|
|
|
out_track = cv2.VideoWriter(out_file, fourcc, fps, size) |
|
|
|
|
|
    for frame_track in video_track:
        # plot() draws all detections; indexing frame_track[0] would fail on frames with none.
        result_track = frame_track.plot()
        out_track.write(result_track)
|
|
|
print("[INFO] Done with frames") |
|
|
|
|
|
out_track.release() |
|
|
|
video.release() |
|
cv2.destroyAllWindows() |
|
|
|
return out_file |
|
|
|
|
|
def pose3d(video): |
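    """Run 3D human pose estimation on a video and return the path to the rendered MP4."""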
|
video = check_extension(video) |
|
print(device) |
|
|
|
|
|
|
|
add_dir = str(uuid.uuid4()) |
|
    vis_out_dir = os.path.join(os.path.dirname(video), add_dir)
|
os.makedirs(vis_out_dir) |
|
|
|
    result_generator = human3d(video,
                               vis_out_dir=vis_out_dir,
                               thickness=2,
                               return_vis=True,
                               rebase_keypoint_height=True,
                               device=device)
|
|
|
    # Consume the generator so the visualization video is written to disk.
    results = list(result_generator)
|
|
|
out_file = glob.glob(os.path.join(vis_out_dir, "*.mp4")) |
|
|
|
return "".join(out_file) |
|
|
|
|
|
def pose2d(video, kpt_threshold): |
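    """Run 2D human pose estimation on a video and return the path to the rendered MP4."""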
|
video = check_extension(video) |
|
print(device) |
|
|
|
|
|
add_dir = str(uuid.uuid4()) |
|
    vis_out_dir = os.path.join(os.path.dirname(video), add_dir)
|
os.makedirs(vis_out_dir) |
|
|
|
    result_generator = human(video,
                             vis_out_dir=vis_out_dir,
                             return_vis=True,
                             thickness=2,
                             rebase_keypoint_height=True,
                             kpt_thr=kpt_threshold,
                             device=device)
|
|
|
    # Consume the generator so the visualization video is written to disk.
    results = list(result_generator)
|
|
|
out_file = glob.glob(os.path.join(vis_out_dir, "*.mp4")) |
|
|
|
return "".join(out_file) |
|
|
|
|
|
def pose2dhand(video, kpt_threshold): |
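    """Run 2D hand pose estimation on a video and return the path to the rendered MP4."""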
|
video = check_extension(video) |
|
print(device) |
|
|
|
|
|
|
|
add_dir = str(uuid.uuid4()) |
|
    vis_out_dir = os.path.join(os.path.dirname(video), add_dir)
|
os.makedirs(vis_out_dir) |
|
|
|
    result_generator = hand(video,
                            vis_out_dir=vis_out_dir,
                            return_vis=True,
                            thickness=2,
                            rebase_keypoint_height=True,
                            kpt_thr=kpt_threshold,
                            device=device)
|
|
|
    # Consume the generator so the visualization video is written to disk.
    results = list(result_generator)
|
|
|
out_file = glob.glob(os.path.join(vis_out_dir, "*.mp4")) |
|
|
|
return "".join(out_file) |
|
|
|
def run_UI(): |
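    """Build and launch the Gradio interface."""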
|
with gr.Blocks() as demo: |
|
with gr.Column(): |
|
with gr.Tab("Upload video"): |
|
with gr.Column(): |
|
with gr.Row(): |
|
with gr.Column(): |
|
                            video_input = gr.Video(source="upload", height=612)
|
|
|
                            file_kpthr = gr.Slider(minimum=0.1, maximum=1, step=0.05, value=0.3, label="Keypoint threshold")
|
|
|
submit_pose_file = gr.Button("Make 2d pose estimation", variant="primary") |
|
submit_pose3d_file = gr.Button("Make 3d pose estimation", variant="primary") |
|
submit_hand_file = gr.Button("Make 2d hand estimation", variant="primary") |
|
submit_detect_file = gr.Button("Detect and track objects", variant="primary") |
|
|
|
with gr.Row(): |
|
                        video_output1 = gr.PlayableVideo(height=512, label="Estimate human 2d poses", show_label=True)
                        video_output2 = gr.PlayableVideo(height=512, label="Estimate human 3d poses", show_label=True)
                        video_output3 = gr.PlayableVideo(height=512, label="Estimate human hand poses", show_label=True)
                        video_output4 = gr.Video(height=512, label="Detection and tracking", show_label=True, format="mp4")
|
|
|
with gr.Tab("Record video with webcam"): |
|
|
|
with gr.Column(): |
|
with gr.Row(): |
|
with gr.Column(): |
|
webcam_input = gr.Video(source="webcam", height=612) |
|
|
|
                            web_kpthr = gr.Slider(minimum=0.1, maximum=1, step=0.05, value=0.3, label="Keypoint threshold")
|
|
|
submit_pose_web = gr.Button("Make 2d pose estimation", variant="primary") |
|
submit_pose3d_web = gr.Button("Make 3d pose estimation", variant="primary") |
|
submit_hand_web = gr.Button("Make 2d hand estimation", variant="primary") |
|
submit_detect_web = gr.Button("Detect and track objects", variant="primary") |
|
with gr.Row(): |
|
                        webcam_output1 = gr.PlayableVideo(height=512, label="Estimate human 2d poses", show_label=True)
                        webcam_output2 = gr.PlayableVideo(height=512, label="Estimate human 3d poses", show_label=True)
                        webcam_output3 = gr.PlayableVideo(height=512, label="Estimate human hand poses", show_label=True)
                        webcam_output4 = gr.Video(height=512, label="Detection and tracking", show_label=True, format="mp4")
|
|
|
with gr.Tab("General information"): |
|
gr.Markdown("You can load the keypoints in python in the following way: ") |
|
gr.Code( |
|
value="""def hello_world(): |
|
return "Hello, world!" |
|
|
|
print(hello_world())""", |
|
language="python", |
|
interactive=True, |
|
show_label=False, |
|
) |
|
|
|
gr.Markdown(""" |
|
\n # Information about the models |
|
|
|
                \n ## Pose models: All the pose estimation models come from the [MMPose](https://github.com/open-mmlab/mmpose) library, which provides pre-trained models for 2D and 3D human pose estimation.
|
|
|
\n ### The 2D pose model is used for estimating the 2D coordinates of human body joints from an image or a video frame. The model uses a convolutional neural network (CNN) to predict the joint locations and their confidence scores. |
|
|
|
\n ### The 2D hand model is a specialized version of the 2D pose model that is designed for hand pose estimation. It uses a similar CNN architecture to the 2D pose model but is trained specifically for detecting the joints in the hand. |
|
|
|
\n ### The 3D pose model is used for estimating the 3D coordinates of human body joints from an image or a video frame. The model uses a combination of 2D pose estimation and depth estimation to infer the 3D joint locations. |
|
|
|
\n ### All of these models are pre-trained on large datasets and can be fine-tuned on custom datasets for specific applications. |
|
|
|
                \n ## Ultralytics detection and tracking model: The `track()` method of the Ultralytics YOLOv8 model is used for object tracking in videos. It takes a video file or a camera stream as input and returns the tracked objects in each frame. The method uses the COCO dataset classes for object detection and tracking.
|
|
|
\n ### The COCO dataset contains 80 classes of objects such as person, car, bicycle, etc. See https://docs.ultralytics.com/datasets/detect/coco/ for all available classes. The `track()` method uses the COCO classes to detect and track the objects in the video frames. |
|
The tracked objects are represented as bounding boxes with labels indicating the class of the object.""") |
|
|
|
|
|
        submit_pose_file.click(fn=pose2d,
                               inputs=[video_input, file_kpthr],
                               outputs=video_output1)

        submit_pose3d_file.click(fn=pose3d,
                                 inputs=video_input,
                                 outputs=video_output2)

        submit_hand_file.click(fn=pose2dhand,
                               inputs=[video_input, file_kpthr],
                               outputs=video_output3)

        submit_detect_file.click(fn=show_tracking,
                                 inputs=video_input,
                                 outputs=video_output4)
|
|
|
|
|
        submit_pose_web.click(fn=pose2d,
                              inputs=[webcam_input, web_kpthr],
                              outputs=webcam_output1)

        submit_pose3d_web.click(fn=pose3d,
                                inputs=webcam_input,
                                outputs=webcam_output2)

        submit_hand_web.click(fn=pose2dhand,
                              inputs=[webcam_input, web_kpthr],
                              outputs=webcam_output3)

        submit_detect_web.click(fn=show_tracking,
                                inputs=webcam_input,
                                outputs=webcam_output4)
|
|
|
demo.launch(server_name="0.0.0.0", server_port=7860) |
|
|
|
if __name__ == "__main__": |
|
run_UI() |
|
|