isLinXu commited on
Commit
0e66889
1 Parent(s): aeac74c

update app.py

Browse files
Files changed (2) hide show
  1. app.py +81 -0
  2. requirements.txt +20 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import os
4
+ os.system("pip install xtcocotools>=1.12")
5
+ os.system("pip install 'mmengine>=0.6.0'")
6
+ os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
7
+ os.system("pip install 'mmdet>=3.0.0,<4.0.0'")
8
+ os.system("pip install 'mmpose'")
9
+
10
+ import PIL
11
+ import cv2
12
+ import mmpose
13
+ import numpy as np
14
+
15
+ import torch
16
+ from mmpose.apis import MMPoseInferencer
17
+ import gradio as gr
18
+
19
+ import warnings
20
+
21
+ warnings.filterwarnings("ignore")
22
+
23
+ mmpose_model_list = ["human", "hand", "face", "animal", "wholebody",
24
+ "vitpose", "vitpose-s", "vitpose-b", "vitpose-l", "vitpose-h"]
25
+
26
+
27
+ def save_image(img, img_path):
28
+ # Convert PIL image to OpenCV image
29
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
30
+ # Save OpenCV image
31
+ cv2.imwrite(img_path, img)
32
+
33
+
34
+ def download_test_image():
35
+ # Images
36
+ torch.hub.download_url_to_file(
37
+ 'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
38
+ 'bus.jpg')
39
+ torch.hub.download_url_to_file(
40
+ 'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
41
+ 'dogs.jpg')
42
+ torch.hub.download_url_to_file(
43
+ 'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
44
+ 'zidane.jpg')
45
+
46
+
47
+ def predict_pose(img, model_name, out_dir):
48
+ img_path = "input_img.jpg"
49
+ save_image(img, img_path)
50
+ device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
51
+ inferencer = MMPoseInferencer(model_name, device=device)
52
+ result_generator = inferencer(img_path, show=False, out_dir=out_dir)
53
+ result = next(result_generator)
54
+ save_dir = './output/visualizations/'
55
+ out_img_path = save_dir + img_path
56
+ out_img = PIL.Image.open(out_img_path)
57
+ return out_img
58
+
59
+ out_dir = "./output/visualizations/"
60
+ if not os.path.exists(out_dir):
61
+ os.makedirs(out_dir)
62
+ download_test_image()
63
+ input_image = gr.inputs.Image(type='pil', label="Original Image")
64
+ model_name = gr.inputs.Dropdown(choices=[m for m in mmpose_model_list], label='Model')
65
+ out_dir = gr.inputs.Textbox(label="Output Directory", default="./output")
66
+ output_image = gr.outputs.Image(type="pil", label="Output Image")
67
+
68
+ examples = [
69
+ ['zidane.jpg', 'human'],
70
+ ['dogs.jpg', 'animal'],
71
+ ]
72
+ title = "MMPose detection web demo"
73
+ description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmpose/main/resources/mmpose-logo.png' width='450''/><div>" \
74
+ "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmpose'>MMPose</a> MMPose 是一款基于 PyTorch 的姿态分析的开源工具箱,是 OpenMMLab 项目的成员之一。" \
75
+ "OpenMMLab Pose Estimation Toolbox and Benchmark..</p>"
76
+ article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmpose'>MMPose</a></p>" \
77
+ "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></a></p>"
78
+
79
+ iface = gr.Interface(fn=predict_pose, inputs=[input_image, model_name, out_dir], outputs=output_image,
80
+ examples=examples, title=title, description=description, article=article)
81
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wget~=3.2
2
+ opencv-python~=4.6.0.66
3
+ numpy~=1.23.0
4
+ torch~=1.13.1
5
+ torchvision~=0.14.1
6
+ pillow~=9.4.0
7
+ gradio~=3.42.0
8
+ ultralytics~=8.0.169
9
+ pyyaml~=6.0
10
+ wandb~=0.13.11
11
+ tqdm~=4.65.0
12
+ matplotlib~=3.7.1
13
+ pandas~=2.0.0
14
+ seaborn~=0.12.2
15
+ requests~=2.31.0
16
+ psutil~=5.9.4
17
+ thop~=0.1.1-2209072238
18
+ timm~=0.9.2
19
+ super-gradients~=3.2.0
20
+ openmim