brxerq commited on
Commit
7c3c7c7
1 Parent(s): 82d596d

Rename app.py to main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -101
  2. main.py +43 -0
app.py DELETED
@@ -1,101 +0,0 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import importlib.util
5
- from PIL import Image
6
- import gradio as gr
7
- from common_detection import perform_detection, resize_image
8
-
9
- # Function to load the TensorFlow Lite model and labels
10
- def load_model_and_labels(model_dir):
11
- pkg = importlib.util.find_spec('tflite_runtime')
12
- if pkg:
13
- from tflite_runtime.interpreter import Interpreter
14
- else:
15
- from tensorflow.lite.python.interpreter import Interpreter
16
-
17
- PATH_TO_CKPT = os.path.join(model_dir, 'detect.tflite')
18
- PATH_TO_LABELS = os.path.join(model_dir, 'labelmap.txt')
19
-
20
- with open(PATH_TO_LABELS, 'r') as f:
21
- labels = [line.strip() for line in f.readlines()]
22
-
23
- if labels[0] == '???':
24
- del(labels[0])
25
-
26
- interpreter = Interpreter(model_path=PATH_TO_CKPT)
27
- interpreter.allocate_tensors()
28
-
29
- input_details = interpreter.get_input_details()
30
- output_details = interpreter.get_output_details()
31
- height = input_details[0]['shape'][1]
32
- width = input_details[0]['shape'][2]
33
- floating_model = (input_details[0]['dtype'] == np.float32)
34
-
35
- return interpreter, labels, input_details, output_details, height, width, floating_model
36
-
37
- # Load models
38
- models = {
39
- "Multi-class model": "model",
40
- "Empty class": "model_2",
41
- "Misalignment class": "model_3"
42
- }
43
-
44
- # Function to perform image detection
45
- def detect_image(model_choice, input_image):
46
- model_dir = models[model_choice]
47
- interpreter, labels, input_details, output_details, height, width, floating_model = load_model_and_labels(model_dir)
48
- image = np.array(input_image)
49
- resized_image = resize_image(image, size=640)
50
- result_image = perform_detection(resized_image, interpreter, labels, input_details, output_details, height, width, floating_model)
51
- return Image.fromarray(result_image)
52
-
53
- # Function to perform video detection
54
- def detect_video(model_choice, input_video):
55
- model_dir = models[model_choice]
56
- interpreter, labels, input_details, output_details, height, width, floating_model = load_model_and_labels(model_dir)
57
- cap = cv2.VideoCapture(input_video)
58
- frames = []
59
-
60
- while cap.isOpened():
61
- ret, frame = cap.read()
62
- if not ret:
63
- break
64
-
65
- resized_frame = resize_image(frame, size=640)
66
- result_frame = perform_detection(resized_frame, interpreter, labels, input_details, output_details, height, width, floating_model)
67
- frames.append(result_frame)
68
-
69
- cap.release()
70
-
71
- if not frames:
72
- raise ValueError("No frames were read from the video.")
73
-
74
- height, width, layers = frames[0].shape
75
- size = (width, height)
76
- output_video_path = "result_" + os.path.basename(input_video)
77
- out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 15, size)
78
-
79
- for frame in frames:
80
- out.write(frame)
81
-
82
- out.release()
83
-
84
- return output_video_path
85
-
86
- app = gr.Blocks()
87
-
88
- with app:
89
- gr.Markdown("## Object Detection using TensorFlow Lite Models")
90
- with gr.Row():
91
- model_choice = gr.Dropdown(label="Select Model", choices=["Multi-class model", "Empty class", "Misalignment class"])
92
- with gr.Tab("Image Detection"):
93
- image_input = gr.Image(type="pil", label="Upload an image")
94
- image_output = gr.Image(type="pil", label="Detection Result")
95
- gr.Button("Submit Image").click(fn=detect_image, inputs=[model_choice, image_input], outputs=image_output)
96
- with gr.Tab("Video Detection"):
97
- video_input = gr.Video(label="Upload a video")
98
- video_output = gr.Video(label="Detection Result")
99
- gr.Button("Submit Video").click(fn=detect_video, inputs=[model_choice, video_input], outputs=video_output)
100
-
101
- app.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import importlib
3
+ import gradio as gr
4
+
5
+ # Function to load the model dynamically
6
+ def load_model(model_name):
7
+ module = importlib.import_module(model_name)
8
+ return module
9
+
10
+ # Define the models
11
+ models = {
12
+ "Multi-class model": "model_1",
13
+ "Empty class": "model_2",
14
+ "Misalignment class": "model_3"
15
+ }
16
+
17
+ # Function to perform image detection
18
+ def detect_image(model_choice, input_image):
19
+ model = load_model(models[model_choice])
20
+ return model.detect_image(input_image)
21
+
22
+ # Function to perform video detection
23
+ def detect_video(model_choice, input_video):
24
+ model = load_model(models[model_choice])
25
+ return model.detect_video(input_video)
26
+
27
+ # Create Gradio app
28
+ app = gr.Blocks()
29
+
30
+ with app:
31
+ gr.Markdown("## Object Detection using TensorFlow Lite Models")
32
+ with gr.Row():
33
+ model_choice = gr.Dropdown(label="Select Model", choices=list(models.keys()))
34
+ with gr.Tab("Image Detection"):
35
+ image_input = gr.Image(type="pil", label="Upload an image")
36
+ image_output = gr.Image(type="pil", label="Detection Result")
37
+ gr.Button("Submit Image").click(fn=detect_image, inputs=[model_choice, image_input], outputs=image_output)
38
+ with gr.Tab("Video Detection"):
39
+ video_input = gr.Video(label="Upload a video")
40
+ video_output = gr.Video(label="Detection Result")
41
+ gr.Button("Submit Video").click(fn=detect_video, inputs=[model_choice, video_input], outputs=video_output)
42
+
43
+ app.launch(share=True)