brxerq commited on
Commit
da33dfa
1 Parent(s): e33a4ef

Update model_3.py

Browse files
Files changed (1) hide show
  1. model_3.py +50 -1
model_3.py CHANGED
@@ -20,4 +20,53 @@ else:
20
  PATH_TO_CKPT = os.path.join(MODEL_DIR, GRAPH_NAME)
21
  PATH_TO_LABELS = os.path.join(MODEL_DIR, LABELMAP_NAME)
22
 
23
- with open(PATH_TO_LABELS, '
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  PATH_TO_CKPT = os.path.join(MODEL_DIR, GRAPH_NAME)
21
  PATH_TO_LABELS = os.path.join(MODEL_DIR, LABELMAP_NAME)
22
 
23
+ with open(PATH_TO_LABELS, 'r') as f:
24
+ labels = [line.strip() for line in f.readlines()]
25
+
26
+ if labels[0] == '???':
27
+ del(labels[0])
28
+
29
+ interpreter = Interpreter(model_path=PATH_TO_CKPT)
30
+ interpreter.allocate_tensors()
31
+
32
+ input_details = interpreter.get_input_details()
33
+ output_details = interpreter.get_output_details()
34
+ height = input_details[0]['shape'][1]
35
+ width = input_details[0]['shape'][2]
36
+ floating_model = (input_details[0]['dtype'] == np.float32)
37
+
38
+ def detect_image(input_image):
39
+ image = np.array(input_image)
40
+ resized_image = cv2.resize(image, (640, 640))
41
+ result_image = perform_detection(resized_image, interpreter, labels, input_details, output_details, height, width, floating_model)
42
+ return Image.fromarray(result_image)
43
+
44
+ def detect_video(input_video):
45
+ cap = cv2.VideoCapture(input_video)
46
+ frames = []
47
+
48
+ while cap.isOpened():
49
+ ret, frame = cap.read()
50
+ if not ret:
51
+ break
52
+
53
+ resized_frame = cv2.resize(frame, (640, 640))
54
+ result_frame = perform_detection(resized_frame, interpreter, labels, input_details, output_details, height, width, floating_model)
55
+ frames.append(result_frame)
56
+
57
+ cap.release()
58
+
59
+ if not frames:
60
+ raise ValueError("No frames were read from the video.")
61
+
62
+ height, width, layers = frames[0].shape
63
+ size = (width, height)
64
+ output_video_path = "result_" + os.path.basename(input_video)
65
+ out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 15, size)
66
+
67
+ for frame in frames:
68
+ out.write(frame)
69
+
70
+ out.release()
71
+
72
+ return output_video_path