Anti-hero commited on
Commit
a7ff454
1 Parent(s): 2188260

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +59 -27
handler.py CHANGED
@@ -1,41 +1,73 @@
1
  import cv2
2
  import numpy as np
3
  import tensorflow as tf
 
4
  import random
5
  from typing import Any, Dict, List
6
 
7
- class EndpointHandler():
8
- def __init__(self, path=""):
9
- self.model = tf.saved_model.load(f'{path}/my_model')
10
- self.classes_1 = ["RoadAccidents", "Fighting", "NormalVideos"]
11
- self.locations = ['Miami', 'Smouha', 'Mandara', 'Sporting', 'Montazah']
12
 
 
13
 
14
- def get_top_k(self, probs, k=1, label_map=None):
15
- if label_map is None:
16
- label_map = self.classes_1
17
- top_predictions = tf.argsort(probs, axis=-1, direction='DESCENDING')[:k]
18
- top_labels = tf.gather(label_map, top_predictions, axis=-1)
19
- top_labels = [label.decode('utf8') for label in top_labels.numpy()]
20
- top_probs = tf.gather(probs, top_predictions, axis=-1).numpy()
21
- return top_labels[0]
22
 
23
- def perform_action_recognition(self, frame, k=1):
24
- outputs = self.model.signatures['serving_default'](image=frame)
25
- probs = tf.nn.softmax(outputs['classifier_head_1'])
26
- return self.get_top_k(probs[0], k=k)
 
 
27
 
28
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
29
- if isinstance(data, str):
30
- data = json.loads(data) # Parse JSON string to dictionary
 
 
 
 
 
31
 
32
- frame = np.array(data.get("frame"))
33
 
34
- if frame is None:
35
- raise ValueError("'frame' is missing from the request body")
36
 
37
- if not isinstance(frame, np.ndarray):
38
- raise ValueError(f"Expected 'frame' to be a np.ndarray, but found {type(frame)}")
 
 
39
 
40
- prediction = self.perform_action_recognition(frame)
41
- return {"prediction": prediction}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import numpy as np
3
  import tensorflow as tf
4
+ import time
5
  import random
6
  from typing import Any, Dict, List
7
 
8
+ class EndpointHandler:
 
 
 
 
9
 
10
+ def __init__(self, path=""):
11
 
12
+ self.model = tf.saved_model.load(f'{path}/my_model')
13
+ self.classes_1 = ["RoadAccidents", "Fighting", "NormalVideos"]
14
+ self.locations = ['Miami', 'Smouha', 'Mandara', 'Sporting', 'Montazah']
 
 
 
 
 
15
 
16
+ def preprocess_frame(self, frame):
17
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
18
+ frame = cv2.resize(frame, (224, 224))
19
+ frame = frame.astype('float32') / 255.0
20
+ frame = np.expand_dims(frame, axis=0)
21
+ return frame
22
 
23
+ def get_top_k(self, probs, k=1, label_map=None):
24
+ if label_map is None:
25
+ label_map = self.classes_1
26
+ top_predictions = tf.argsort(probs, axis=-1, direction='DESCENDING')[:k]
27
+ top_labels = tf.gather(label_map, top_predictions, axis=-1)
28
+ top_labels = [label.decode('utf8') for label in top_labels.numpy()]
29
+ top_probs = tf.gather(probs, top_predictions, axis=-1).numpy()
30
+ return top_labels[0]
31
 
32
+ def perform_action_recognition(self, url, k=1):
33
 
34
+ cap = cv2.VideoCapture(url)
35
+ start_time = time.time()
36
 
37
+ while True:
38
+ ret, frame = cap.read()
39
+ if not ret:
40
+ break
41
 
42
+ preprocessed_frame = self.preprocess_frame(frame)
43
+ outputs = self.model.signatures['serving_default'](image=preprocessed_frame[tf.newaxis])
44
+ probs = tf.nn.softmax(outputs['classifier_head_1'])
45
+ current_time = time.time() - start_time
46
+ m, s = divmod(current_time, 60)
47
+ h, m = divmod(m, 60)
48
+ ip_address = url.split("/")[-1]
49
+ output = {
50
+ "class": self.get_top_k(probs[0], k=k),
51
+ "elapsed_time": f"{int(h):02d}:{int(m):02d}:{int(s):02d}",
52
+ "location": random.choice(self.locations),
53
+ "ip_address": ip_address
54
+ }
55
+ yield output
56
+
57
+ cap.release()
58
+
59
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
60
+
61
+ url = data.get("url")
62
+
63
+ if url is None:
64
+ raise ValueError("'url' is missing from the request body")
65
+
66
+ if not isinstance(url, str):
67
+ raise ValueError(f"Expected 'url' to be a str, but found {type(url)}")
68
+
69
+ outputs = []
70
+ for output in self.perform_action_recognition(url):
71
+ outputs.append(output)
72
+
73
+ return outputs