Ayushnangia committed
Commit 9ef9ef2
1 Parent(s): e8dd937
Files changed (3)
  1. __init__.py +1 -0
  2. raft.py +117 -0
  3. utils.py +135 -0
__init__.py ADDED
@@ -0,0 +1 @@
+ from .raft import Raft
raft.py ADDED
@@ -0,0 +1,117 @@
+ import cv2
+ import time
+ import numpy as np
+ import onnx
+ import onnxruntime
+
+ from .utils import flow_to_image
+
+ class Raft():
+
+     def __init__(self, model_path):
+
+         # Initialize model
+         self.initialize_model(model_path)
+
+     def __call__(self, img1, img2):
+
+         return self.estimate_flow(img1, img2)
+
+     def initialize_model(self, model_path):
+
+         self.session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+
+         # Get model info
+         self.get_input_details()
+         self.get_output_details()
+
+     def estimate_flow(self, img1, img2):
+
+         input_tensor1 = self.prepare_input(img1)
+         input_tensor2 = self.prepare_input(img2)
+
+         outputs = self.inference(input_tensor1, input_tensor2)
+
+         self.flow_map = self.process_output(outputs)
+
+         return self.flow_map
+
+     def prepare_input(self, img):
+
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         self.img_height, self.img_width = img.shape[:2]
+
+         img_input = cv2.resize(img, (self.input_width, self.input_height))
+
+         # img_input = img_input/255
+         img_input = img_input.transpose(2, 0, 1)
+         img_input = img_input[np.newaxis, :, :, :]
+
+         return img_input.astype(np.float32)
+
+     def inference(self, input_tensor1, input_tensor2):
+
+         # start = time.time()
+         outputs = self.session.run(self.output_names, {self.input_names[0]: input_tensor1,
+                                                        self.input_names[1]: input_tensor2})
+
+         # print(time.time() - start)
+         return outputs
+
+     def process_output(self, output):
+
+         flow_map = output[1][0].transpose(1, 2, 0)
+
+         return flow_map
+
+     def draw_flow(self):
+
+         # Convert flow to image
+         flow_img = flow_to_image(self.flow_map)
+
+         # Convert to BGR
+         flow_img = cv2.cvtColor(flow_img, cv2.COLOR_RGB2BGR)
+
+         # Resize the flow map to match the input image shape
+         return cv2.resize(flow_img, (self.img_width, self.img_height))
+
+     def get_input_details(self):
+
+         model_inputs = self.session.get_inputs()
+         self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+
+         self.input_shape = model_inputs[0].shape
+         self.input_height = self.input_shape[2]
+         self.input_width = self.input_shape[3]
+
+     def get_output_details(self):
+
+         model_outputs = self.session.get_outputs()
+         self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+
+         self.output_shape = model_outputs[0].shape
+         self.output_height = self.output_shape[2]
+         self.output_width = self.output_shape[3]
+
+ if __name__ == '__main__':
+
+     from imread_from_url import imread_from_url
+
+     # Initialize model
+     model_path = '../models/raft_things_iter20_480x640.onnx'
+     flow_estimator = Raft(model_path)
+
+     # Read inference images
+     img1 = imread_from_url("https://github.com/princeton-vl/RAFT/blob/master/demo-frames/frame_0016.png?raw=true")
+     img2 = imread_from_url("https://github.com/princeton-vl/RAFT/blob/master/demo-frames/frame_0025.png?raw=true")
+
+     # Estimate flow and colorize it
+     flow_map = flow_estimator(img1, img2)
+     flow_img = flow_estimator.draw_flow()
+
+     combined_img = np.hstack((img1, img2, flow_img))
+
+     cv2.namedWindow("Estimated flow", cv2.WINDOW_NORMAL)
+     cv2.imshow("Estimated flow", combined_img)
+     cv2.waitKey(0)
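
The demo block above drives the wrapper with two still frames fetched over HTTP. A minimal sketch (not part of the commit) of using the same Raft class on consecutive video frames follows; the package name `raft`, the model path, and the video path are placeholders, and the relative import `from .utils import flow_to_image` assumes the module is imported as part of the package rather than run as a loose script.

import cv2
import numpy as np
from raft import Raft  # placeholder package name; Raft is re-exported by __init__.py

flow_estimator = Raft("models/raft_things_iter20_480x640.onnx")  # assumed model location

cap = cv2.VideoCapture("input.mp4")  # placeholder video file
ret, prev_frame = cap.read()

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    # Flow from the previous frame to the current one, then colorized
    flow_estimator(prev_frame, frame)
    flow_img = flow_estimator.draw_flow()

    # Show the current frame next to the colorized flow map
    cv2.imshow("Estimated flow", np.hstack((frame, flow_img)))
    prev_frame = frame
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
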
utils.py ADDED
@@ -0,0 +1,135 @@
+ # Ref: https://github.com/liruoteng/OpticalFlowToolkit/blob/5cf87b947a0032f58c922bbc22c0afb30b90c418/lib/flowlib.py#L249
+
+ import numpy as np
+
+ UNKNOWN_FLOW_THRESH = 1e7
+
+ def make_color_wheel():
+     """
+     Generate color wheel according to the Middlebury color code
+     :return: Color wheel
+     """
+     RY = 15
+     YG = 6
+     GC = 4
+     CB = 11
+     BM = 13
+     MR = 6
+
+     ncols = RY + YG + GC + CB + BM + MR
+
+     colorwheel = np.zeros([ncols, 3])
+
+     col = 0
+
+     # RY
+     colorwheel[0:RY, 0] = 255
+     colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
+     col += RY
+
+     # YG
+     colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
+     colorwheel[col:col+YG, 1] = 255
+     col += YG
+
+     # GC
+     colorwheel[col:col+GC, 1] = 255
+     colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
+     col += GC
+
+     # CB
+     colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
+     colorwheel[col:col+CB, 2] = 255
+     col += CB
+
+     # BM
+     colorwheel[col:col+BM, 2] = 255
+     colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
+     col += BM
+
+     # MR
+     colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+     colorwheel[col:col+MR, 0] = 255
+
+     return colorwheel
+
+ colorwheel = make_color_wheel()
+
+ def compute_color(u, v):
+     """
+     Compute the optical flow color map
+     :param u: optical flow horizontal map
+     :param v: optical flow vertical map
+     :return: optical flow in color code
+     """
+     [h, w] = u.shape
+     img = np.zeros([h, w, 3])
+     nanIdx = np.isnan(u) | np.isnan(v)
+     u[nanIdx] = 0
+     v[nanIdx] = 0
+
+     ncols = np.size(colorwheel, 0)
+
+     rad = np.sqrt(u**2+v**2)
+
+     a = np.arctan2(-v, -u) / np.pi
+
+     fk = (a+1) / 2 * (ncols - 1) + 1
+
+     k0 = np.floor(fk).astype(int)
+
+     k1 = k0 + 1
+     k1[k1 == ncols+1] = 1
+     f = fk - k0
+
+     for i in range(0, np.size(colorwheel, 1)):
+         tmp = colorwheel[:, i]
+         col0 = tmp[k0-1] / 255
+         col1 = tmp[k1-1] / 255
+         col = (1-f) * col0 + f * col1
+
+         idx = rad <= 1
+         col[idx] = 1-rad[idx]*(1-col[idx])
+         notidx = np.logical_not(idx)
+
+         col[notidx] *= 0.75
+         img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
+
+     return img
+
+ def flow_to_image(flow):
+     """
+     Convert flow into a Middlebury color code image
+     :param flow: optical flow map
+     :return: optical flow image in Middlebury color
+     """
+     u = flow[:, :, 0]
+     v = flow[:, :, 1]
+
+     maxu = -999.
+     maxv = -999.
+     minu = 999.
+     minv = 999.
+
+     idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+     u[idxUnknow] = 0
+     v[idxUnknow] = 0
+
+     maxu = max(maxu, np.max(u))
+     minu = min(minu, np.min(u))
+
+     maxv = max(maxv, np.max(v))
+     minv = min(minv, np.min(v))
+
+     rad = np.sqrt(u ** 2 + v ** 2)
+     maxrad = max(-1, np.max(rad))
+
+     u = u/(maxrad + np.finfo(float).eps)
+     v = v/(maxrad + np.finfo(float).eps)
+
+     img = compute_color(u, v)
+
+     idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
+     img[idx] = 0
+
+     return np.uint8(img)
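
flow_to_image is the only function raft.py consumes directly. A quick sanity-check sketch (not part of the commit) that colorizes a synthetic radial flow field is below; the import path assumes the script sits next to utils.py, and cv2 is only used for display.

import cv2
import numpy as np

from utils import flow_to_image  # assumes utils.py is importable as a top-level module

# Synthetic radial flow: every pixel points away from the image center
h, w = 240, 320
ys, xs = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.dstack((xs - w / 2, ys - h / 2)) / 50.0

flow_img = flow_to_image(flow)  # RGB uint8, Middlebury color coding
cv2.imshow("Color wheel check", cv2.cvtColor(flow_img, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)
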