brxerq committed on
Commit 9efab48
1 Parent(s): fb51722

Update app.py

Files changed (1):
  app.py +93 -22
app.py CHANGED
@@ -1,31 +1,102 @@
  # app.py
  import gradio as gr
- import importlib

- def load_model(model_name):
-     module = importlib.import_module(model_name)
-     return module

  app = gr.Blocks()

  with app:
      gr.Markdown("## Object Detection using TensorFlow Lite Models")
      with gr.Row():
-         model_choice = gr.Dropdown(label="Select Model", choices=["model_1", "model_2", "model_3"])
-         image_input = gr.Image(type="pil", label="Upload an image")
-         image_output = gr.Image(type="pil", label="Detection Result")
-         video_input = gr.Video(label="Upload a video")
-         video_output = gr.Video(label="Detection Result")
-
-     def image_detection(model_name, input_image):
-         model = load_model(model_name)
-         return model.detect_image(input_image)
-
-     def video_detection(model_name, input_video):
-         model = load_model(model_name)
-         return model.detect_video(input_video)
-
-     gr.Button("Submit Image").click(fn=image_detection, inputs=[model_choice, image_input], outputs=image_output)
-     gr.Button("Submit Video").click(fn=video_detection, inputs=[model_choice, video_input], outputs=video_output)
-
- app.launch()
 
  # app.py
+ import os
+ import cv2
+ import numpy as np
+ import importlib.util
+ from PIL import Image
  import gradio as gr
+ from common_detection import perform_detection, resize_image

+ # Function to load the TensorFlow Lite model and labels
+ def load_model_and_labels(model_dir):
+     pkg = importlib.util.find_spec('tflite_runtime')
+     if pkg:
+         from tflite_runtime.interpreter import Interpreter
+     else:
+         from tensorflow.lite.python.interpreter import Interpreter
+
+     PATH_TO_CKPT = os.path.join(model_dir, 'detect.tflite')
+     PATH_TO_LABELS = os.path.join(model_dir, 'labelmap.txt')
+
+     with open(PATH_TO_LABELS, 'r') as f:
+         labels = [line.strip() for line in f.readlines()]
+
+     if labels[0] == '???':
+         del(labels[0])
+
+     interpreter = Interpreter(model_path=PATH_TO_CKPT)
+     interpreter.allocate_tensors()
+
+     input_details = interpreter.get_input_details()
+     output_details = interpreter.get_output_details()
+     height = input_details[0]['shape'][1]
+     width = input_details[0]['shape'][2]
+     floating_model = (input_details[0]['dtype'] == np.float32)
+
+     return interpreter, labels, input_details, output_details, height, width, floating_model
+
+ # Load models
+ models = {
+     "Multi-class model": "model",
+     "Empty class": "model_2",
+     "Misalignment class": "model_3"
+ }
+
+ # Function to perform image detection
+ def detect_image(model_choice, input_image):
+     model_dir = models[model_choice]
+     interpreter, labels, input_details, output_details, height, width, floating_model = load_model_and_labels(model_dir)
+     image = np.array(input_image)
+     resized_image = resize_image(image, size=640)
+     result_image = perform_detection(resized_image, interpreter, labels, input_details, output_details, height, width, floating_model)
+     return Image.fromarray(result_image)
+
+ # Function to perform video detection
+ def detect_video(model_choice, input_video):
+     model_dir = models[model_choice]
+     interpreter, labels, input_details, output_details, height, width, floating_model = load_model_and_labels(model_dir)
+     cap = cv2.VideoCapture(input_video)
+     frames = []
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         resized_frame = resize_image(frame, size=640)
+         result_frame = perform_detection(resized_frame, interpreter, labels, input_details, output_details, height, width, floating_model)
+         frames.append(result_frame)
+
+     cap.release()
+
+     if not frames:
+         raise ValueError("No frames were read from the video.")
+
+     height, width, layers = frames[0].shape
+     size = (width, height)
+     output_video_path = "result_" + os.path.basename(input_video)
+     out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 15, size)
+
+     for frame in frames:
+         out.write(frame)
+
+     out.release()
+
+     return output_video_path

  app = gr.Blocks()

  with app:
      gr.Markdown("## Object Detection using TensorFlow Lite Models")
      with gr.Row():
+         model_choice = gr.Dropdown(label="Select Model", choices=["Multi-class model", "Empty class", "Misalignment class"])
+     with gr.Tab("Image Detection"):
+         image_input = gr.Image(type="pil", label="Upload an image")
+         image_output = gr.Image(type="pil", label="Detection Result")
+         gr.Button("Submit Image").click(fn=detect_image, inputs=[model_choice, image_input], outputs=image_output)
+     with gr.Tab("Video Detection"):
+         video_input = gr.Video(label="Upload a video")
+         video_output = gr.Video(label="Detection Result")
+         gr.Button("Submit Video").click(fn=detect_video, inputs=[model_choice, video_input], outputs=video_output)
+
+ app.launch()
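
The new app.py imports perform_detection and resize_image from a common_detection module that is not part of this commit. The sketch below is a hypothetical stand-in for that helper, written for an SSD-style detect.tflite: the function names and argument order come from the calls in app.py, but the (boxes, classes, scores) output-tensor order, the longest-side resize, and the 0.5 confidence threshold are assumptions, not the author's actual implementation.

# common_detection.py -- hypothetical sketch, not part of this commit
import cv2
import numpy as np

def resize_image(image, size=640):
    """Resize a NumPy image (H x W x C) so its longest side equals `size`."""
    h, w = image.shape[:2]
    scale = size / max(h, w)
    return cv2.resize(image, (int(w * scale), int(h * scale)))

def perform_detection(image, interpreter, labels, input_details, output_details,
                      height, width, floating_model, min_conf=0.5):
    """Run one TFLite detection pass and draw labelled boxes on a copy of `image`."""
    imH, imW = image.shape[:2]

    # Prepare the input tensor at the model's expected resolution.
    input_data = cv2.resize(image, (int(width), int(height)))
    input_data = np.expand_dims(input_data, axis=0)
    if floating_model:
        input_data = (np.float32(input_data) - 127.5) / 127.5
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # Assumed output order (boxes, classes, scores); some exported models differ.
    boxes = interpreter.get_tensor(output_details[0]['index'])[0]
    classes = interpreter.get_tensor(output_details[1]['index'])[0]
    scores = interpreter.get_tensor(output_details[2]['index'])[0]

    result = image.copy()
    for i in range(len(scores)):
        if scores[i] < min_conf or scores[i] > 1.0:
            continue
        # Box coordinates are normalised; scale them back to the image size.
        ymin = int(max(1, boxes[i][0] * imH))
        xmin = int(max(1, boxes[i][1] * imW))
        ymax = int(min(imH, boxes[i][2] * imH))
        xmax = int(min(imW, boxes[i][3] * imW))
        cv2.rectangle(result, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
        label = '%s: %d%%' % (labels[int(classes[i])], int(scores[i] * 100))
        cv2.putText(result, label, (xmin, max(ymin - 7, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
    return result

With a helper like this on the Python path next to app.py, the app should run end to end, provided each model directory ("model", "model_2", "model_3") contains the detect.tflite and labelmap.txt files that load_model_and_labels expects.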