harpreetsahota committed on
Commit
d13f105
1 Parent(s): a0ae2b6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import cv2
4
+ import gradio as gr
5
+ import numpy as np
6
+ import requests
7
+ from PIL import Image
8
+
9
+
10
+ from super_gradients.common.object_names import Models
11
+ from super_gradients.training import models
12
+ from super_gradients.training.utils.visualization.detection import draw_bbox
13
+ from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
14
+
15
import os

# Initialize the pose-estimation model (YOLO-NAS-Pose L, 17 COCO keypoints).
# The checkpoint path can be overridden via the YOLO_NAS_POSE_CHECKPOINT
# environment variable; the default preserves the original hard-coded
# Colab path so existing deployments keep working.
yolo_nas_pose = models.get(
    "yolo_nas_pose_l",
    num_classes=17,
    checkpoint_path=os.environ.get(
        "YOLO_NAS_POSE_CHECKPOINT",
        "/content/yolo_nas_pose_l_coco_pose.pth",
    ),
)
20
def _draw_pose_visualization(image, pose_data):
    """Render poses/boxes/keypoints from one image prediction onto `image`."""
    return PoseVisualization.draw_poses(
        image=image,
        poses=pose_data.poses,
        boxes=pose_data.bboxes_xyxy,
        scores=pose_data.scores,
        is_crowd=None,
        edge_links=pose_data.edge_links,
        edge_colors=pose_data.edge_colors,
        keypoint_colors=pose_data.keypoint_colors,
        joint_thickness=2,
        box_thickness=2,
        keypoint_radius=5,
    )


def process_and_predict(url=None,
                        image=None,
                        confidence=0.5,
                        iou=0.5):
    """Run YOLO-NAS-Pose on an image supplied by URL or upload.

    Args:
        url: Optional image URL; when non-empty it takes precedence over
            the uploaded image.
        image: Optional uploaded image as a numpy array (from Gradio).
        confidence: Detection confidence threshold in [0, 1].
        iou: NMS IoU threshold in [0, 1].

    Returns:
        A tuple ``(annotated_image, skeleton_only_image)`` of numpy arrays,
        or ``None`` when no input is provided.
    """
    if url and url.strip():
        # Fetch the image. A timeout keeps the UI from hanging on a dead
        # host, and raise_for_status stops us from trying to decode an
        # HTML error page as an image.
        response = requests.get(url.strip(), timeout=30)
        response.raise_for_status()
        # Normalize to 3-channel RGB: PIL may return RGBA/palette/grayscale,
        # which the model does not expect.
        pil_image = Image.open(BytesIO(response.content)).convert("RGB")
        image = np.array(pil_image)
    elif image is None:
        return None  # no URL and no upload — nothing to predict on

    result = yolo_nas_pose.predict(image, conf=confidence, iou=iou)

    # Extract the single-image prediction produced by the call above.
    image_prediction = result._images_prediction_lst[0]
    pose_data = image_prediction.prediction

    # Draw the poses twice: once over the original image, and once over a
    # black canvas of the same shape to get a skeleton-only view.
    output_image = _draw_pose_visualization(image_prediction.image, pose_data)
    skeleton_image = _draw_pose_visualization(
        np.zeros_like(image_prediction.image), pose_data
    )

    return output_image, skeleton_image
78
# Define the Gradio interface wiring the prediction function to the UI.
iface = gr.Interface(
    fn=process_and_predict,
    inputs=[
        gr.Textbox(placeholder="Enter Image URL", label="Image URL"),
        gr.Image(label="Upload Image", type="numpy"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5,
                  label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5,
                  label="IoU Threshold"),
    ],
    outputs=[
        gr.Image(label="Estimated Pose"),
        gr.Image(label="Skeleton Only"),
    ],
    title="YOLO-NAS-Pose Demo",
    # Original description was a truncated sentence and mentioned webcam
    # input that no component provides; state only what the UI offers.
    description=(
        "Upload an image or enter an image URL to run a pretrained "
        "YOLO-NAS-Pose L model for pose estimation."
    ),
    live=False,
    # Gradio expects "never"/"auto"/"manual" here; False is not a valid
    # value in current releases.
    allow_flagging="never",
)

# Launch the interface.
iface.launch()