# app.py — YOLO-NAS-Pose Gradio demo.
# Author: harpreetsahota. Commit d13f105 ("Create app.py"), 3.34 kB.
from io import BytesIO
import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.utils.visualization.detection import draw_bbox
from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
# Build the large YOLO-NAS-Pose model once, at import time, so every request
# handled by process_and_predict reuses the same loaded weights.
# num_classes=17: number of keypoints the head predicts; presumably the COCO
# keypoint layout named in the checkpoint file — confirm if swapping weights.
# NOTE(review): "/content/..." looks like a Colab path; confirm it exists
# wherever this app is deployed.
_POSE_MODEL_KWARGS = {
    "num_classes": 17,
    "checkpoint_path": "/content/yolo_nas_pose_l_coco_pose.pth",
}
yolo_nas_pose = models.get("yolo_nas_pose_l", **_POSE_MODEL_KWARGS)
def _load_image_from_url(url, timeout=30):
    """Download ``url`` and return the image as an RGB numpy array.

    Raises ``requests.HTTPError`` on a non-2xx response (rather than letting
    PIL fail on an HTML error page) and ``requests.Timeout`` if the server
    does not answer within ``timeout`` seconds.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Force 3 channels: PNGs may be RGBA/paletted and some images are
    # grayscale, which would otherwise reach the model with the wrong depth.
    with Image.open(BytesIO(response.content)) as pil_image:
        return np.array(pil_image.convert("RGB"))


def _draw_poses(canvas, pose_data):
    """Render ``pose_data`` onto ``canvas`` via PoseVisualization.draw_poses.

    Shared by the overlay and skeleton-only outputs so both use identical
    line thickness and keypoint radius.
    """
    return PoseVisualization.draw_poses(
        image=canvas,
        poses=pose_data.poses,
        boxes=pose_data.bboxes_xyxy,
        scores=pose_data.scores,
        is_crowd=None,
        edge_links=pose_data.edge_links,
        edge_colors=pose_data.edge_colors,
        keypoint_colors=pose_data.keypoint_colors,
        joint_thickness=2,
        box_thickness=2,
        keypoint_radius=5,
    )


def process_and_predict(url=None,
                        image=None,
                        confidence=0.5,
                        iou=0.5):
    """Run YOLO-NAS-Pose on an image given either by URL or by upload.

    Args:
        url: Image URL. Takes priority over ``image`` when non-blank.
        image: Uploaded image as a numpy array (``gr.Image(type='numpy')``).
        confidence: Confidence threshold, passed to ``predict`` as ``conf``.
        iou: IoU threshold, passed to ``predict`` as ``iou``.

    Returns:
        A ``(overlay, skeleton)`` pair: the predicted poses drawn on the input
        image, and the same poses drawn on an all-zero canvas of the same
        shape. Returns ``(None, None)`` when neither input is provided.

    Raises:
        requests.HTTPError / requests.Timeout: if downloading ``url`` fails.
    """
    if url is not None and url.strip():
        image = _load_image_from_url(url.strip())
    elif image is None:
        # One value per output component: a bare None does not match the
        # two-output interface.
        return None, None

    result = yolo_nas_pose.predict(image, conf=confidence, iou=iou)

    # NOTE(review): relies on the private ``_images_prediction_lst`` of the
    # super-gradients result object; re-check on library upgrades.
    image_prediction = result._images_prediction_lst[0]
    pose_data = image_prediction.prediction

    output_image = _draw_poses(image_prediction.image, pose_data)
    # Same poses on an all-zero canvas of identical shape/dtype -> skeleton only.
    skeleton_image = _draw_poses(np.zeros_like(image_prediction.image), pose_data)
    return output_image, skeleton_image
# Define the Gradio interface. Inputs map positionally onto
# process_and_predict(url, image, confidence, iou); its (overlay, skeleton)
# return pair maps onto the two output images.
iface = gr.Interface(
    fn=process_and_predict,
    inputs=[
        gr.Textbox(placeholder="Enter Image URL", label="Image URL"),
        gr.Image(label="Upload Image", type="numpy"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="IoU Threshold"),
    ],
    outputs=[
        gr.Image(label="Estimated Pose"),
        gr.Image(label="Skeleton Only"),
    ],
    title="YOLO-NAS-Pose Demo",
    # The previous text ended mid-sentence and advertised a webcam input
    # that this interface does not provide.
    description=(
        "Upload an image or enter an image URL to run a pretrained "
        "YOLO-NAS-Pose L model for pose estimation. If both are given, "
        "the URL is used."
    ),
    live=False,
    # allow_flagging takes a string mode; boolean False is a deprecated
    # alias for "never".
    allow_flagging="never",
)

# Launch the interface
iface.launch()