ac5113 commited on
Commit
5ae9c92
·
1 Parent(s): 0da7b5a

initial app test

Browse files
Files changed (1) hide show
  1. app.py +215 -4
app.py CHANGED
@@ -1,7 +1,218 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import glob
4
+ import numpy as np
5
+ import cv2
6
+ import PIL.Image as pil_img
7
+ import sys
8
  import gradio as gr
9
 
10
+ import trimesh
11
+ import pyrender
12
 
13
+ print(os.path.abspath(__file__))
14
+ os.system('pip install /home/user/app/vendor/pyrender')
15
+ sys.path.append('/home/user/app/vendor/pyrender')
16
+ os.system('pip install gradio==3.47.1')
17
+ os.system('sh fetch_data.sh')
18
+
19
+ from models.deco import DECO
20
+ from common import constants
21
+
22
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
23
+
24
+ if torch.cuda.is_available():
25
+ device = torch.device('cuda')
26
+ else:
27
+ device = torch.device('cpu')
28
+
29
+ def initiate_model(model_path):
30
+ deco_model = DECO('hrnet', True, device)
31
+
32
+ print(f'Loading weights from {model_path}')
33
+ checkpoint = torch.load(model_path)
34
+ deco_model.load_state_dict(checkpoint['deco'], strict=True)
35
+
36
+ deco_model.eval()
37
+
38
+ return deco_model
39
+
40
+ def render_image(scene, img_res, img=None, viewer=False):
41
+ '''
42
+ Render the given pyrender scene and return the image. Can also overlay the mesh on an image.
43
+ '''
44
+ if viewer:
45
+ pyrender.Viewer(scene, use_raymond_lighting=True)
46
+ return 0
47
+ else:
48
+ r = pyrender.OffscreenRenderer(viewport_width=img_res,
49
+ viewport_height=img_res,
50
+ point_size=1.0)
51
+ color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
52
+ color = color.astype(np.float32) / 255.0
53
+
54
+ if img is not None:
55
+ valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
56
+ input_img = img.detach().cpu().numpy()
57
+ output_img = (color[:, :, :-1] * valid_mask +
58
+ (1 - valid_mask) * input_img)
59
+ else:
60
+ output_img = color
61
+ return output_img
62
+
63
+ def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500):
64
+ # Setup the scene
65
+ scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0],
66
+ ambient_light=(0.3, 0.3, 0.3))
67
+ # add mesh for camera
68
+ camera_pose = np.eye(4)
69
+ camera_rotation = np.eye(3, 3)
70
+ camera_translation = np.array([0., 0, 2.5])
71
+ camera_pose[:3, :3] = camera_rotation
72
+ camera_pose[:3, 3] = camera_rotation @ camera_translation
73
+ pyrencamera = pyrender.camera.IntrinsicsCamera(
74
+ fx=focal_length, fy=focal_length,
75
+ cx=camera_center, cy=camera_center)
76
+ scene.add(pyrencamera, pose=camera_pose)
77
+ # create and add light
78
+ light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1)
79
+ light_pose = np.eye(4)
80
+ for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]:
81
+ light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp)
82
+ # out_mesh.vertices.mean(0) + np.array(lp)
83
+ scene.add(light, pose=light_pose)
84
+ # add body mesh
85
+ material = pyrender.MetallicRoughnessMaterial(
86
+ metallicFactor=0.0,
87
+ alphaMode='OPAQUE',
88
+ baseColorFactor=(1.0, 1.0, 0.9, 1.0))
89
+ mesh_images = []
90
+
91
+ # resize input image to fit the mesh image height
92
+ img_height = img_res
93
+ img_width = int(img_height * img.shape[1] / img.shape[0])
94
+ img = cv2.resize(img, (img_width, img_height))
95
+ mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
96
+
97
+ for sideview_angle in [0, 90, 180, 270]:
98
+ out_mesh = mesh.copy()
99
+ rot = trimesh.transformations.rotation_matrix(
100
+ np.radians(sideview_angle), [0, 1, 0])
101
+ out_mesh.apply_transform(rot)
102
+ out_mesh = pyrender.Mesh.from_trimesh(
103
+ out_mesh,
104
+ material=material)
105
+ mesh_pose = np.eye(4)
106
+ scene.add(out_mesh, pose=mesh_pose, name='mesh')
107
+ output_img = render_image(scene, img_res)
108
+ output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
109
+ output_img = np.asarray(output_img)[:, :, :3]
110
+ mesh_images.append(output_img)
111
+ # delete the previous mesh
112
+ prev_mesh = scene.get_nodes(name='mesh').pop()
113
+ scene.remove_node(prev_mesh)
114
+
115
+ # show upside down view
116
+ for topview_angle in [90, 270]:
117
+ out_mesh = mesh.copy()
118
+ rot = trimesh.transformations.rotation_matrix(
119
+ np.radians(topview_angle), [1, 0, 0])
120
+ out_mesh.apply_transform(rot)
121
+ out_mesh = pyrender.Mesh.from_trimesh(
122
+ out_mesh,
123
+ material=material)
124
+ mesh_pose = np.eye(4)
125
+ scene.add(out_mesh, pose=mesh_pose, name='mesh')
126
+ output_img = render_image(scene, img_res)
127
+ output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
128
+ output_img = np.asarray(output_img)[:, :, :3]
129
+ mesh_images.append(output_img)
130
+ # delete the previous mesh
131
+ prev_mesh = scene.get_nodes(name='mesh').pop()
132
+ scene.remove_node(prev_mesh)
133
+
134
+ # stack images
135
+ IMG = np.hstack(mesh_images)
136
+ IMG = pil_img.fromarray(IMG)
137
+ IMG.thumbnail((3000, 3000))
138
+ return IMG
139
+
140
+ def main(args, img_src, out_dir, mesh_colour=[130, 130, 130, 255], annot_colour=[0, 255, 0, 255]):
141
+ if os.path.isdir(img_src):
142
+ images = glob.iglob(img_src + '/*', recursive=True)
143
+ else:
144
+ images = [img_src]
145
+
146
+ deco_model = initiate_model(args)
147
+
148
+ smpl_path = os.path.join(constants.SMPL_MODEL_DIR, 'smpl_neutral_tpose.ply')
149
+
150
+ for img_name in images:
151
+ img = cv2.imread(img_name)
152
+ img = cv2.resize(img, (256, 256), cv2.INTER_CUBIC)
153
+ img = img.transpose(2,0,1)/255.0
154
+ img = img[np.newaxis,:,:,:]
155
+ img = torch.tensor(img, dtype = torch.float32).to(device)
156
+
157
+ cont, _, _ = deco_model(img)
158
+ cont = cont.detach().cpu().numpy().squeeze()
159
+ cont_smpl = []
160
+ for indx, i in enumerate(cont):
161
+ if i >= 0.5:
162
+ cont_smpl.append(indx)
163
+
164
+ img = img.detach().cpu().numpy()
165
+ img = np.transpose(img[0], (1, 2, 0))
166
+ img = img * 255
167
+ img = img.astype(np.uint8)
168
+
169
+ contact_smpl = np.zeros((1, 1, 6890))
170
+ contact_smpl[0][0][cont_smpl] = 1
171
+
172
+ body_model_smpl = trimesh.load(smpl_path, process=False)
173
+ for vert in range(body_model_smpl.visual.vertex_colors.shape[0]):
174
+ body_model_smpl.visual.vertex_colors[vert] = mesh_colour
175
+ body_model_smpl.visual.vertex_colors[cont_smpl] = annot_colour
176
+
177
+ rend = create_scene(body_model_smpl, img)
178
+ os.makedirs(os.path.join(out_dir, 'Renders'), exist_ok=True)
179
+ rend.save(os.path.join(out_dir, 'Renders', os.path.basename(img_name).split('.')[0] + '.png'))
180
+
181
+ mesh_out_dir = os.path.join(out_dir, 'Preds', os.path.basename(img_name).split('.')[0])
182
+ os.makedirs(mesh_out_dir, exist_ok=True)
183
+
184
+ print(f'Saving mesh to {mesh_out_dir}')
185
+ body_model_smpl.export(os.path.join(mesh_out_dir, 'pred.obj'))
186
+
187
+ return out_dir
188
+
189
+ with gr.Blocks(title="DECO", css=".gradio-container") as demo:
190
+
191
+ gr.HTML("""<div style="font-weight:bold; text-align:center; color:royalblue;">DECO</div>""")
192
+
193
+ with gr.Row():
194
+ with gr.Column():
195
+ input_image = gr.Image(label="Input image", type="pil")
196
+ with gr.Column():
197
+ output_image = gr.Image(label="Renders", type="pil")
198
+ output_meshes = gr.File(label="3D meshes")
199
+
200
+ gr.HTML("""<br/>""")
201
+
202
+ # with gr.Row():
203
+ # threshold = gr.Slider(0, 1.0, value=0.6, label='Detection Threshold')
204
+ # send_btn = gr.Button("Infer")
205
+ # send_btn.click(fn=main, inputs=[input_image, threshold], outputs=[output_image, output_meshes])
206
+
207
+ # example_images = gr.Examples([
208
+ # ['/home/user/app/assets/test1.png'],
209
+ # ['/home/user/app/assets/test2.jpg'],
210
+ # ['/home/user/app/assets/test3.jpg'],
211
+ # ['/home/user/app/assets/test4.jpg'],
212
+ # ['/home/user/app/assets/test5.jpg'],
213
+ # ],
214
+ # inputs=[input_image, 0.6])
215
+
216
+
217
+ #demo.queue()
218
+ demo.launch(debug=True)