huntrezz committed on
Commit
0143794
·
verified ·
1 Parent(s): f8b3886

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -50
app.py CHANGED
@@ -1,9 +1,13 @@
1
  import cv2
2
  import torch
3
- from transformers import DPTForDepthEstimation, DPTImageProcessor
4
  import numpy as np
 
5
  import time
6
  import warnings
 
 
 
 
7
  warnings.filterwarnings("ignore", message="It looks like you are trying to rescale already rescaled images.")
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -25,72 +29,83 @@ def manual_normalize(depth_map):
25
  return np.zeros_like(depth_map, dtype=np.uint8)
26
 
27
  frame_skip = 4
28
- frame_count = 0
29
  color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
30
 
31
- prev_frame_time = 0
32
 
33
- while True:
34
- ret, frame = cap.read()
35
- if not ret:
36
- break
 
 
37
 
38
- frame_count += 1
39
- if frame_count % frame_skip != 0:
40
- continue
 
 
 
41
 
42
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
43
- resized_frame = resize_image(rgb_frame)
 
44
 
45
- inputs = processor(images=resized_frame, return_tensors="pt").to(device)
46
- inputs = {k: v.to(torch.float16) for k, v in inputs.items()}
 
 
47
 
48
- with torch.no_grad():
49
- outputs = model(**inputs)
50
- predicted_depth = outputs.predicted_depth
51
 
52
- depth_map = predicted_depth.squeeze().cpu().numpy()
 
53
 
54
- # Check Input Data
55
- print(f"depth_map shape: {depth_map.shape}")
56
- print(f"depth_map min: {np.min(depth_map)}, max: {np.max(depth_map)}")
57
- print(f"depth_map dtype: {depth_map.dtype}")
58
 
59
- # Handle invalid values
60
- depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
 
61
 
62
- # Ensure depth_map is in float32 format
63
- depth_map = depth_map.astype(np.float32)
64
 
65
- # Check for zero-sized arrays
66
- if depth_map.size == 0:
67
- print("Error: depth_map is empty")
68
- depth_map = np.zeros((256, 256), dtype=np.uint8)
69
- else:
70
- # Handle empty or constant arrays
71
- if np.any(depth_map) and np.min(depth_map) != np.max(depth_map):
72
- depth_map = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
73
  else:
74
- depth_map = np.zeros_like(depth_map, dtype=np.uint8)
 
 
 
75
 
76
- # Use manual normalization as a fallback
77
- if np.all(depth_map == 0):
78
- depth_map = manual_normalize(depth_map)
79
 
80
- depth_map_colored = cv2.applyColorMap(depth_map, color_map)
81
- depth_map_colored = cv2.resize(depth_map_colored, (frame.shape[1], frame.shape[0]))
 
 
 
 
82
 
83
- combined = np.hstack((frame, depth_map_colored))
 
 
84
 
85
- new_frame_time = time.time()
86
- fps = 1 / (new_frame_time - prev_frame_time)
87
- prev_frame_time = new_frame_time
88
- cv2.putText(combined, f"FPS: {int(fps)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
89
 
90
- cv2.imshow('Webcam and Depth Map', combined)
 
91
 
92
- if cv2.waitKey(1) & 0xFF == ord('q'):
93
- break
 
94
 
95
- cap.release()
96
- cv2.destroyAllWindows()
 
1
# Standard library
import asyncio
import json
import time
import warnings

# Third party
import cv2
import numpy as np
import torch
import websockets
from transformers import DPTForDepthEstimation, DPTImageProcessor

# The DPT processor warns when fed already-rescaled images; this run is benign.
warnings.filterwarnings("ignore", message="It looks like you are trying to rescale already rescaled images.")

# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
29
  return np.zeros_like(depth_map, dtype=np.uint8)
30
 
31
# Only every 4th captured frame is run through the depth model.
frame_skip = 4

# 256-entry inferno palette: maps a uint8 depth value to a BGR color.
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)

# Registry of currently-connected websocket clients.
connected = set()
35
 
36
async def broadcast(message):
    """Send *message* to every connected websocket client.

    Clients whose connection has closed are pruned from the registry.
    """
    # Iterate over a snapshot: the original removed from `connected` while
    # iterating it, which raises RuntimeError ("Set changed size during
    # iteration") the first time a closed client is encountered.
    for websocket in list(connected):
        try:
            await websocket.send(message)
        except websockets.exceptions.ConnectionClosed:
            # discard (not remove): handler() may also drop this entry.
            connected.discard(websocket)
42
 
43
async def handler(websocket, path):
    """Register a client for broadcasts and keep it until the socket closes.

    The `path` argument is required by the websockets handler signature but
    is unused here.
    """
    connected.add(websocket)
    try:
        await websocket.wait_closed()
    finally:
        # discard (not remove): broadcast() may already have pruned a client
        # whose connection dropped mid-send, so remove() could raise KeyError.
        connected.discard(websocket)
49
 
50
async def process_frames():
    """Capture webcam frames, estimate depth with the DPT model, and
    broadcast each result to all websocket clients as JSON.

    Every `frame_skip`-th frame is processed; the payload is
    {'depthMap': <2-D list of uint8>, 'rgbFrame': <H x W x 3 list>}.
    Runs until the capture device stops delivering frames or 'q' is read
    by cv2.waitKey.
    """
    frame_count = 0
    # Initialize to "now" rather than 0 so the first FPS sample is not a
    # near-zero value computed against the epoch.
    prev_frame_time = time.time()

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % frame_skip != 0:
            # Yield to the event loop even on skipped frames. Without this,
            # a run of skips (and the blocking cap.read()) starves the loop
            # and the websocket server can never accept a connection.
            await asyncio.sleep(0)
            continue

        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized_frame = resize_image(rgb_frame)

        inputs = processor(images=resized_frame, return_tensors="pt").to(device)
        inputs = {k: v.to(torch.float16) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            predicted_depth = outputs.predicted_depth

        depth_map = predicted_depth.squeeze().cpu().numpy()

        # Sanitize: NaN/inf from the model would break min/max normalization.
        depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
        depth_map = depth_map.astype(np.float32)

        if depth_map.size == 0:
            # Empty prediction: substitute a blank 256x256 map.
            depth_map = np.zeros((256, 256), dtype=np.uint8)
        else:
            # Constant or all-zero maps cannot be min-max normalized.
            if np.any(depth_map) and np.min(depth_map) != np.max(depth_map):
                depth_map = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
            else:
                depth_map = np.zeros_like(depth_map, dtype=np.uint8)

        # Fallback path if normalization produced an all-zero map.
        if np.all(depth_map == 0):
            depth_map = manual_normalize(depth_map)

        data = {
            'depthMap': depth_map.tolist(),
            'rgbFrame': rgb_frame.tolist()
        }

        await broadcast(json.dumps(data))
        # Yield after each processed frame so the server task keeps running.
        await asyncio.sleep(0)

        new_frame_time = time.time()
        # NOTE(review): fps is computed but never displayed or sent since the
        # imshow overlay was removed — confirm whether it should be dropped.
        fps = 1 / (new_frame_time - prev_frame_time)
        prev_frame_time = new_frame_time

        # NOTE(review): no OpenCV window exists anymore, so waitKey likely
        # never sees 'q'; kept for compatibility — confirm it can be removed.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
105
 
106
async def main():
    """Start the websocket server and run frame processing concurrently."""
    ws_server = await websockets.serve(handler, "localhost", 8765)
    tasks = (ws_server.wait_closed(), process_frames())
    await asyncio.gather(*tasks)


if __name__ == "__main__":
    asyncio.run(main())