Christen Millerdurai committed on
Commit
9564652
1 Parent(s): 15bc41b
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +150 -15
  3. demo.py → infererence.py +45 -96
.gitignore CHANGED
@@ -159,4 +159,5 @@ cython_debug/
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
  src/Ev2Hands/outputs
162
- src/HandSimulator/logs
 
 
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
  src/Ev2Hands/outputs
162
+ src/HandSimulator/logs
163
+ gradio_cached_examples/
app.py CHANGED
@@ -1,25 +1,160 @@
1
  import gradio as gr
2
- import requests
 
 
3
 
 
 
 
4
 
5
 
6
- import gradio as gr
7
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def video_identity(video):
11
- print(video)
12
- return video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
 
 
 
 
 
 
 
14
 
15
- demo = gr.Interface(video_identity,
16
- gr.Video(),
17
- "playable_video",
18
- examples=[
19
- os.path.join(os.path.dirname(__file__),
20
- "example/video.mp4")],
21
- cache_examples=True)
22
 
23
- if __name__ == "__main__":
24
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
25
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
 
6
+ import esim_py
7
+ from infererence import process_events, Ev2Hands
8
+ from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH, REF_PERIOD
9
 
10
 
11
+ os.makedirs("temp", exist_ok=True)
12
+ ev2hands = Ev2Hands()
13
+
14
+
15
def get_frames(video_in, trim_in):
    """Read the leading ``trim_in`` seconds of a video as resized frames.

    Args:
        video_in: Path of the video file handed to ``cv2.VideoCapture``.
        trim_in: Length, in seconds, of the leading section to keep.

    Returns:
        A ``(frames, fps)`` tuple. ``frames`` is a list of frames resized to
        ``(OUTPUT_WIDTH, OUTPUT_HEIGHT)``; it is empty when the video cannot
        be opened or read. ``fps`` is the frame rate reported by OpenCV,
        which is 0.0 when the backend cannot determine it.
    """
    cap = cv2.VideoCapture(video_in)
    frames = []

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Number of frames that make up the first `trim_in` seconds.
        stop_frame = int(trim_in * fps)

        print("video fps: " + str(fps))

        # Bound the loop by the number of frames collected so far; the
        # previous "append, then test i > stop_frame" order let two extra
        # frames through past the requested cut point.
        while cap.isOpened() and len(frames) < stop_frame:
            ret, frame = cap.read()
            if not ret:
                # End of stream or decode failure.
                break

            frames.append(cv2.resize(frame, (OUTPUT_WIDTH, OUTPUT_HEIGHT)))
    finally:
        # Always hand the decoder back, even if resizing raised.
        cap.release()

    return frames, fps
42
+
43
+
44
+
45
def infer(video_inp, trim_in, threshold):
    """Simulate events from a video and render Ev2Hands predictions for it.

    The first frame only initialises the event simulator. Every later frame
    produces one event frame and one prediction frame, each appended to its
    own output video.

    Args:
        video_inp: Path of the uploaded video (``None`` when nothing was
            uploaded).
        trim_in: Length, in seconds, of the leading section to process.
        threshold: Contrast threshold used for both positive and negative
            events in the simulator.

    Returns:
        ``(event_frame_vid_path, prediction_vid_path)``: paths of the two
        written ``.mp4`` files.

    Raises:
        gr.Error: If no video was supplied, its frame rate is unknown, or
            fewer than two frames could be read (one is needed to
            initialise the simulator and at least one more to emit events).
    """
    if not video_inp:
        raise gr.Error("Please provide an input video.")

    frames, fps = get_frames(video_inp, trim_in)
    if not fps or fps <= 0:
        # Guard the 1 / fps below; OpenCV reports 0 for an unknown rate.
        raise gr.Error("Could not determine the frame rate of the input video.")
    if len(frames) < 2:
        raise gr.Error("Could not read enough frames from the input video.")

    ts_s = 1 / fps
    ts_ns = ts_s * 1e9  # convert s to ns

    # The same threshold is applied to both event polarities.
    esim = esim_py.EventSimulator(threshold, threshold, REF_PERIOD, 1e-4, True)

    event_frame_vid_path = 'temp/event_video.mp4'
    prediction_vid_path = 'temp/prediction_video.mp4'

    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    event_video = cv2.VideoWriter(event_frame_vid_path, fourcc, fps, (width, height))
    prediction_video = cv2.VideoWriter(prediction_vid_path, fourcc, fps, (width, height))

    try:
        for idx, frame in enumerate(frames):
            frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            # Log intensity; the small offset keeps log() finite on black pixels.
            frame_log = np.log(frame_gray.astype("float32") / 255 + 1e-4)

            current_ts_ns = idx * ts_ns

            if idx == 0:
                # The first frame only seeds the simulator state.
                esim.init(frame_log, current_ts_ns)
                continue

            events = esim.generateEventFromCVImage(frame_log, current_ts_ns)
            data = process_events(events)

            prediction_frame = ev2hands(data)
            event_frame = data['event_frame'].cpu().numpy().astype(dtype=np.uint8)

            event_video.write(event_frame)
            prediction_video.write(prediction_frame)
    finally:
        # Finalise both files even when simulation or inference fails.
        event_video.release()
        prediction_video.release()

    return event_frame_vid_path, prediction_vid_path
87
+
88
 
89
+ title = """
90
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
91
+ <div
92
+ style="
93
+ display: inline-flex;
94
+ align-items: center;
95
+ gap: 0.8rem;
96
+ font-size: 1.75rem;
97
+ "
98
+ >
99
+ <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
100
+ Pix2Pix Video
101
+ </h1>
102
+ </div>
103
+ <p style="margin-bottom: 10px; font-size: 94%">
104
+ Apply Instruct Pix2Pix Diffusion to a video
105
+ </p>
106
+ </div>
107
+ """
108
 
109
+ article = """
110
+
111
+ <div class="footer">
112
+ <p>
113
+ Examples by <a href="https://twitter.com/CitizenPlain" target="_blank">Nathan Shipley</a> •&nbsp;
114
+ Follow <a href="https://twitter.com/fffiloni" target="_blank">Sylvain Filoni</a> for future updates 🤗
115
+ </p>
116
+ </div>
117
+ <div id="may-like-container" style="display: flex;justify-content: center;flex-direction: column;align-items: center;margin-bottom: 30px;">
118
+ <p>You may also like: </p>
119
+ <div id="may-like-content" style="display:flex;flex-wrap: wrap;align-items:center;height:20px;">
120
+
121
+ <svg height="20" width="162" style="margin-left:4px;margin-bottom: 6px;">
122
+ <a href="https://huggingface.co/spaces/timbrooks/instruct-pix2pix" target="_blank">
123
+ <image href="https://img.shields.io/badge/🤗 Spaces-Instruct_Pix2Pix-blue" src="https://img.shields.io/badge/🤗 Spaces-Instruct_Pix2Pix-blue.png" height="20"/>
124
+ </a>
125
+ </svg>
126
+
127
+ </div>
128
+
129
+ </div>
130
+
131
+ """
132
 
133
+ with gr.Blocks(css='style.css') as demo:
134
+ with gr.Column(elem_id="col-container"):
135
+ gr.HTML(title)
136
+ with gr.Row():
137
+ with gr.Column():
138
+ video_inp = gr.Video(label="Video source", elem_id="input-vid")
139
+ with gr.Row():
140
+ trim_in = gr.Slider(label="Cut video at (s)", minimum=1, maximum=5, step=1, value=1)
141
+ threshold = gr.Slider(label="Event Threshold", minimum=0.1, maximum=1, step=0.05, value=0.5)
142
 
143
+ with gr.Column():
144
+ event_frame_out = gr.Video(label="Event Frame", elem_id="video-output")
145
+ prediction_out = gr.Video(label="Ev2Hands result", elem_id="video-output")
 
 
 
 
146
 
147
+ gr.HTML("""
148
+ <a style="display:inline-block" href="https://huggingface.co/spaces/fffiloni/Pix2Pix-Video?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
149
+ work with longer videos / skip the queue:
150
+ """, elem_id="duplicate-container")
151
 
152
+ submit_btn = gr.Button("Run Ev2Hands")
153
+
154
+ inputs = [video_inp, trim_in, threshold]
155
+ outputs = [event_frame_out, prediction_out]
156
+ gr.HTML(article)
157
+
158
+ submit_btn.click(infer, inputs, outputs)
159
+
160
+ demo.queue(max_size=12).launch(server_name="0.0.0.0", server_port=7860)
demo.py → infererence.py RENAMED
@@ -138,84 +138,52 @@ def demo(net, device, data):
138
  return frames
139
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- def main():
143
- arg_parser.demo()
144
- os.makedirs('outputs', exist_ok=True)
145
-
146
- device = torch.device('cpu')
147
-
148
- net = TEHNetWrapper(device=device)
149
-
150
- save_path = os.environ['CHECKPOINT_PATH']
151
- batch_size = int(os.environ['BATCH_SIZE'])
152
-
153
- checkpoint = torch.load(save_path, map_location=device)
154
- net.load_state_dict(checkpoint['state_dict'], strict=True)
155
-
156
- renderer = pyrender.OffscreenRenderer(viewport_width=OUTPUT_WIDTH, viewport_height=OUTPUT_HEIGHT)
157
-
158
- scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
159
- light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
160
- light_pose = np.eye(4)
161
- light_pose[:3, 3] = np.array([0, -1, 1])
162
- scene.add(light, pose=light_pose)
163
- light_pose[:3, 3] = np.array([0, 1, 1])
164
- scene.add(light, pose=light_pose)
165
- light_pose[:3, 3] = np.array([1, 1, 2])
166
- scene.add(light, pose=light_pose)
167
-
168
- rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
169
-
170
- mano_hands = net.hands
171
-
172
- # camera = cv2.VideoCapture(0)
173
- input_video_stream = cv2.VideoCapture('video.mp4')
174
-
175
-
176
-
177
- video_fps = 25
178
- video = cv2.VideoWriter('outputs/video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (3 * OUTPUT_WIDTH, OUTPUT_HEIGHT))
179
-
180
-
181
- POS_THRESHOLD = 0.5
182
- NEG_THRESHOLD = 0.5
183
- REF_PERIOD = 0.000
184
-
185
- esim = esim_py.EventSimulator(POS_THRESHOLD, NEG_THRESHOLD, REF_PERIOD, 1e-4, True)
186
-
187
-
188
- fps = cv2.CAP_PROP_FPS
189
- ts_s = 1 / fps
190
- ts_ns = ts_s * 1e9 # convert s to ns
191
-
192
- is_init = False
193
- idx = 0
194
- while True:
195
- _, frame_bgr = input_video_stream.read()
196
- frame_bgr = cv2.resize(frame_bgr, (OUTPUT_WIDTH, OUTPUT_HEIGHT))
197
- frame_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
198
- frame_log = np.log(frame_gray.astype("float32") / 255 + 1e-4)
199
- height, width = frame_log.shape[:2]
200
-
201
- current_ts_ns = idx * ts_ns
202
-
203
- if not is_init:
204
- esim.init(frame_log, current_ts_ns)
205
- is_init = True
206
- idx += 1
207
-
208
- continue
209
- idx += 1
210
-
211
- events = esim.generateEventFromCVImage(frame_log, current_ts_ns)
212
- data = process_events(events)
213
-
214
- event_frame = data['event_frame'].cpu().numpy().astype(dtype=np.uint8)
215
-
216
- cv2.imwrite(f"outputs/event_frame_{idx}.png", event_frame)
217
 
218
- print(idx, event_frame.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  frame = demo(net=net, device=device, data=data)[0]
221
  seg_mask = frame['seg_mask']
@@ -231,33 +199,14 @@ def main():
231
  pred_meshes = trimesh.util.concatenate(pred_meshes)
232
  pred_meshes.apply_transform(rot)
233
 
234
- camera = MAIN_CAMERA
235
-
236
- nc = pyrender.Node(camera=camera, matrix=np.eye(4))
237
- scene.add_node(nc)
238
 
239
  mesh_node = pyrender.Node(mesh=pyrender.Mesh.from_trimesh(pred_meshes))
240
  scene.add_node(mesh_node)
241
  pred_rgb, depth = renderer.render(scene)
242
  scene.remove_node(mesh_node)
243
- scene.remove_node(nc)
244
 
245
  pred_rgb = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
246
  pred_rgb[pred_rgb == 255] = 0
247
 
248
- img_stack = np.hstack([event_frame, seg_mask, pred_rgb])
249
- video.write(img_stack)
250
-
251
- cv2.imshow('image', img_stack)
252
- c = cv2.waitKey(1)
253
-
254
- if c == ord('q'):
255
- video.release()
256
- exit(0)
257
-
258
- video.release()
259
-
260
-
261
- if __name__ == '__main__':
262
- main()
263
 
 
138
  return frames
139
 
140
 
141
+ class Ev2Hands:
142
+ def __init__(self) -> None:
143
+ arg_parser.demo()
144
+ device = torch.device('cpu')
145
+ net = TEHNetWrapper(device=device)
146
+
147
+ save_path = os.environ['CHECKPOINT_PATH']
148
+
149
+ checkpoint = torch.load(save_path, map_location=device)
150
+ net.load_state_dict(checkpoint['state_dict'], strict=True)
151
+
152
+ renderer = pyrender.OffscreenRenderer(viewport_width=OUTPUT_WIDTH, viewport_height=OUTPUT_HEIGHT)
153
+
154
+ scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
155
+ light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
156
+ light_pose = np.eye(4)
157
+ light_pose[:3, 3] = np.array([0, -1, 1])
158
+ scene.add(light, pose=light_pose)
159
+ light_pose[:3, 3] = np.array([0, 1, 1])
160
+ scene.add(light, pose=light_pose)
161
+ light_pose[:3, 3] = np.array([1, 1, 2])
162
+ scene.add(light, pose=light_pose)
163
+
164
+ camera = MAIN_CAMERA
165
+ nc = pyrender.Node(camera=camera, matrix=np.eye(4))
166
+ scene.add_node(nc)
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
170
+
171
+ mano_hands = net.hands
172
+
173
+ self.net = net
174
+ self.device = device
175
+ self.mano_hands = mano_hands
176
+ self.rot = rot
177
+ self.renderer = renderer
178
+ self.scene = scene
179
+
180
+ def __call__(self, data):
181
+ net = self.net
182
+ device = self.device
183
+ mano_hands = self.mano_hands
184
+ rot = self.rot
185
+ renderer = self.renderer
186
+ scene = self.scene
187
 
188
  frame = demo(net=net, device=device, data=data)[0]
189
  seg_mask = frame['seg_mask']
 
199
  pred_meshes = trimesh.util.concatenate(pred_meshes)
200
  pred_meshes.apply_transform(rot)
201
 
 
 
 
 
202
 
203
  mesh_node = pyrender.Node(mesh=pyrender.Mesh.from_trimesh(pred_meshes))
204
  scene.add_node(mesh_node)
205
  pred_rgb, depth = renderer.render(scene)
206
  scene.remove_node(mesh_node)
 
207
 
208
  pred_rgb = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
209
  pred_rgb[pred_rgb == 255] = 0
210
 
211
+ return pred_rgb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212