brxerq commited on
Commit
e33a4ef
1 Parent(s): 683ca2f

Rename app_model_3.py to model_3.py

Browse files
Files changed (2) hide show
  1. app_model_3.py +0 -102
  2. model_3.py +23 -0
app_model_3.py DELETED
@@ -1,102 +0,0 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import importlib.util
5
- import gradio as gr
6
- from PIL import Image
7
-
8
- # Load the TensorFlow Lite model for Model 1
9
- MODEL_DIR = 'model_3'
10
- GRAPH_NAME = 'detect.tflite'
11
- LABELMAP_NAME = 'labelmap.txt'
12
-
13
- pkg = importlib.util.find_spec('tflite_runtime')
14
- if pkg:
15
- from tflite_runtime.interpreter import Interpreter
16
- from tflite_runtime.interpreter import load_delegate
17
- else:
18
- from tensorflow.lite.python.interpreter import Interpreter
19
- from tensorflow.lite.python.interpreter import load_delegate
20
-
21
- PATH_TO_CKPT = os.path.join(MODEL_DIR, GRAPH_NAME)
22
- PATH_TO_LABELS = os.path.join(MODEL_DIR, LABELMAP_NAME)
23
-
24
- # Load the label map
25
- with open(PATH_TO_LABELS, 'r') as f:
26
- labels = [line.strip() for line in f.readlines()]
27
-
28
- if labels[0] == '???':
29
- del(labels[0])
30
-
31
- # Load the TensorFlow Lite model
32
- interpreter = Interpreter(model_path=PATH_TO_CKPT)
33
- interpreter.allocate_tensors()
34
-
35
- input_details = interpreter.get_input_details()
36
- output_details = interpreter.get_output_details()
37
- height = input_details[0]['shape'][1]
38
- width = input_details[0]['shape'][2]
39
- floating_model = (input_details[0]['dtype'] == np.float32)
40
-
41
- input_mean = 127.5
42
- input_std = 127.5
43
-
44
- outname = output_details[0]['name']
45
- if ('StatefulPartitionedCall' in outname):
46
- boxes_idx, classes_idx, scores_idx = 1, 3, 0
47
- else:
48
- boxes_idx, classes_idx, scores_idx = 0, 1, 2
49
-
50
- def perform_detection(image, interpreter, labels):
51
- imH, imW, _ = image.shape
52
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
53
- image_resized = cv2.resize(image_rgb, (width, height))
54
- input_data = np.expand_dims(image_resized, axis=0)
55
-
56
- if floating_model:
57
- input_data = (np.float32(input_data) - input_mean) / input_std
58
-
59
- interpreter.set_tensor(input_details[0]['index'], input_data)
60
- interpreter.invoke()
61
-
62
- boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0]
63
- classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0]
64
- scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0]
65
-
66
- detections = []
67
- for i in range(len(scores)):
68
- if ((scores[i] > 0.5) and (scores[i] <= 1.0)):
69
- ymin = int(max(1, (boxes[i][0] * imH)))
70
- xmin = int(max(1, (boxes[i][1] * imW)))
71
- ymax = int(min(imH, (boxes[i][2] * imH)))
72
- xmax = int(min(imW, (boxes[i][3] * imW)))
73
-
74
- cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
75
- object_name = labels[int(classes[i])]
76
- label = '%s: %d%%' % (object_name, int(scores[i] * 100))
77
- labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
78
- label_ymin = max(ymin, labelSize[1] + 10)
79
- cv2.rectangle(image, (xmin, label_ymin - labelSize[1] - 10), (xmin + labelSize[0], label_ymin + baseLine - 10), (255, 255, 255), cv2.FILLED)
80
- cv2.putText(image, label, (xmin, label_ymin - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
81
-
82
- detections.append([object_name, scores[i], xmin, ymin, xmax, ymax])
83
- return image
84
-
85
- def resize_image(image, size=640):
86
- return cv2.resize(image, (size, size))
87
-
88
- def detect_image(input_image):
89
- image = np.array(input_image)
90
- resized_image = resize_image(image, size=640) # Resize input image
91
- result_image = perform_detection(resized_image, interpreter, labels)
92
- return Image.fromarray(result_image)
93
-
94
- app = gr.Interface(
95
- detect_image,
96
- inputs=gr.inputs.Image(type="pil", label="Upload an image"),
97
- outputs="image",
98
- title="Model 3:Misalignment-class Object Detection",
99
- theme="compact"
100
- )
101
-
102
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_3.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_3.py
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ import importlib.util
6
+ from PIL import Image
7
+ import gradio as gr
8
+ from common_detection import perform_detection
9
+
10
+ MODEL_DIR = 'model_3'
11
+ GRAPH_NAME = 'detect.tflite'
12
+ LABELMAP_NAME = 'labelmap.txt'
13
+
14
+ pkg = importlib.util.find_spec('tflite_runtime')
15
+ if pkg:
16
+ from tflite_runtime.interpreter import Interpreter
17
+ else:
18
+ from tensorflow.lite.python.interpreter import Interpreter
19
+
20
+ PATH_TO_CKPT = os.path.join(MODEL_DIR, GRAPH_NAME)
21
+ PATH_TO_LABELS = os.path.join(MODEL_DIR, LABELMAP_NAME)
22
+
23
+ with open(PATH_TO_LABELS, '