# movinet62 / handler.py
# Last commit: "Update handler.py" (2188260, verified) by Anti-hero
# Size: 1.61 kB · scan result: No virus
import json
import random
from typing import Any, Dict, List

import cv2
import numpy as np
import tensorflow as tf
class EndpointHandler():
    """Inference endpoint that classifies an input with a TensorFlow SavedModel.

    The instance is callable: pass a request payload (dict or JSON string)
    containing a ``"frame"`` entry and it returns ``{"prediction": label}``.
    """

    def __init__(self, path=""):
        """Load the SavedModel and set up the label tables.

        Args:
            path: Directory that contains the ``my_model`` SavedModel folder.
        """
        self.model = tf.saved_model.load(f'{path}/my_model')
        # Class names indexed by the position of the model's output scores
        # (get_top_k gathers from this list by score index).
        self.classes_1 = ["RoadAccidents", "Fighting", "NormalVideos"]
        # NOTE(review): not read by any method in this class; presumably used
        # by other code — confirm before relying on or removing it.
        self.locations = ['Miami', 'Smouha', 'Mandara', 'Sporting', 'Montazah']

    def get_top_k(self, probs, k=1, label_map=None):
        """Return the name of the highest-probability class.

        Args:
            probs: 1-D vector of class probabilities, one per entry of
                ``label_map`` (the ``[:k]`` slice below assumes rank 1).
            k: How many top candidates to rank. Only the single best label is
                returned, so any ``k >= 1`` gives the same result; values below
                1 are treated as 1.
            label_map: Class names indexed like ``probs``; defaults to
                ``self.classes_1``.

        Returns:
            The most probable class name as a ``str``.
        """
        if label_map is None:
            label_map = self.classes_1
        # k <= 0 would leave no candidates and make top_labels[0] fail.
        k = max(1, k)
        top_predictions = tf.argsort(probs, axis=-1, direction='DESCENDING')[:k]
        top_labels = tf.gather(label_map, top_predictions, axis=-1)
        # String tensors come back from .numpy() as bytes; decode them to str.
        top_labels = [label.decode('utf8') for label in top_labels.numpy()]
        return top_labels[0]

    def perform_action_recognition(self, frame, k=1):
        """Run the serving signature on ``frame`` and return the top label.

        Args:
            frame: Array passed as the signature's ``image`` input.
                NOTE(review): its required shape/dtype is fixed by the saved
                signature, which is not visible here; presumably a batched
                clip — confirm against the exported model.
            k: Forwarded to :meth:`get_top_k`.

        Returns:
            The predicted class name for the first element of the batch.
        """
        outputs = self.model.signatures['serving_default'](image=frame)
        # Turn the head's raw scores into probabilities; only batch item 0
        # is classified.
        probs = tf.nn.softmax(outputs['classifier_head_1'])
        return self.get_top_k(probs[0], k=k)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload: a dict, or a JSON string encoding one, with
                the model input under the ``"frame"`` key.

        Returns:
            ``{"prediction": <class name>}``.

        Raises:
            ValueError: If the payload is not a JSON object, or ``"frame"`` is
                missing or null.
        """
        if isinstance(data, str):
            data = json.loads(data)  # Parse JSON string to dictionary
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected the request body to be a JSON object, but found {type(data)}"
            )
        raw_frame = data.get("frame")
        # Validate before converting: np.array(None) is a 0-d object ndarray,
        # so a None check after conversion could never trigger.
        if raw_frame is None:
            raise ValueError("'frame' is missing from the request body")
        frame = np.asarray(raw_frame)
        prediction = self.perform_action_recognition(frame)
        return {"prediction": prediction}