noahzhy commited on
Commit
eb74463
1 Parent(s): 1b4374d

chore: Update TFliteDemo constructor parameters, model.tflite file, and preprocessing methods

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -7,7 +7,7 @@ import gradio as gr
7
  import tensorflow as tf
8
 
9
 
10
- def get_sample_images():
11
  list_ = glob.glob(os.path.join(os.path.dirname(__file__), 'samples/*.jpg'))
12
  # sort by name
13
  list_.sort(key=lambda x: int(x.split('/')[-1].split('.')[0]))
@@ -40,6 +40,7 @@ def center_fit(img, w, h, inter=cv2.INTER_NEAREST, top_left=True):
40
  inter = cv2.INTER_AREA
41
  # resize img
42
  img = cv2.resize(img, (int(img_w * ratio), int(img_h * ratio)), interpolation=inter)
 
43
  # get new img shape
44
  img_h, img_w = img.shape[:2]
45
  # get start point
@@ -69,9 +70,8 @@ def load_dict(dict_path='label.names'):
69
 
70
 
71
  class TFliteDemo:
72
- def __init__(self, model_path, blank=0, conf_mode="min"):
73
  self.blank = blank
74
- self.conf_mode = conf_mode
75
  self.interpreter = tf.lite.Interpreter(model_path=model_path)
76
  self.interpreter.allocate_tensors()
77
  self.inputs = self.interpreter.get_input_details()
@@ -90,30 +90,28 @@ class TFliteDemo:
90
  if img is None:
91
  raise ValueError('img is None')
92
  image = img.copy()
93
- image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
94
  image = center_fit(image, 192, 96, top_left=True)
95
  image = np.reshape(image, (1, *image.shape, 1)).astype(np.float32) / 255.0
96
  return image
97
 
98
- def get_confidence(self, pred, mode="mean"):
99
  _argmax = np.argmax(pred, axis=-1)
100
  _idx = _argmax != pred.shape[-1] - 1
101
  conf = pred[_idx, _argmax[_idx]]
102
- conf = np.exp(conf)
103
- return np.min(conf) if mode == "min" else np.mean(conf)
104
 
105
  def postprocess(self, pred):
106
  label = decode_label(pred, load_dict())
107
- conf = self.get_confidence(pred[0], mode=self.conf_mode)
108
  # keep 4 decimal places
109
  conf = float('{:.4f}'.format(conf))
110
  return label, conf
111
 
112
-
113
- def inference(img):
114
- img = demo.preprocess(img)
115
- pred = demo.inference(img)
116
- return demo.postprocess(pred)
117
 
118
 
119
  if __name__ == '__main__':
@@ -127,8 +125,8 @@ if __name__ == '__main__':
127
  '''
128
  # init model
129
  demo = TFliteDemo(os.path.join(os.path.dirname(__file__), 'model.tflite'))
130
- interface = gr.Interface(
131
- fn=inference,
132
  inputs="image",
133
  outputs=[
134
  gr.Textbox(label="Plate Number", type="text"),
@@ -136,6 +134,6 @@ if __name__ == '__main__':
136
  ],
137
  title=_TITLE,
138
  description=_DESCRIPTION,
139
- examples=get_sample_images(),
140
  )
141
- interface.launch()
 
7
  import tensorflow as tf
8
 
9
 
10
+ def get_samples():
11
  list_ = glob.glob(os.path.join(os.path.dirname(__file__), 'samples/*.jpg'))
12
  # sort by name
13
  list_.sort(key=lambda x: int(x.split('/')[-1].split('.')[0]))
 
40
  inter = cv2.INTER_AREA
41
  # resize img
42
  img = cv2.resize(img, (int(img_w * ratio), int(img_h * ratio)), interpolation=inter)
43
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
44
  # get new img shape
45
  img_h, img_w = img.shape[:2]
46
  # get start point
 
70
 
71
 
72
  class TFliteDemo:
73
+ def __init__(self, model_path, blank=0):
74
  self.blank = blank
 
75
  self.interpreter = tf.lite.Interpreter(model_path=model_path)
76
  self.interpreter.allocate_tensors()
77
  self.inputs = self.interpreter.get_input_details()
 
90
  if img is None:
91
  raise ValueError('img is None')
92
  image = img.copy()
93
+
94
  image = center_fit(image, 192, 96, top_left=True)
95
  image = np.reshape(image, (1, *image.shape, 1)).astype(np.float32) / 255.0
96
  return image
97
 
98
+ def get_confidence(self, pred):
99
  _argmax = np.argmax(pred, axis=-1)
100
  _idx = _argmax != pred.shape[-1] - 1
101
  conf = pred[_idx, _argmax[_idx]]
102
+ return np.min(np.exp(conf))
 
103
 
104
  def postprocess(self, pred):
105
  label = decode_label(pred, load_dict())
106
+ conf = self.get_confidence(pred[0])
107
  # keep 4 decimal places
108
  conf = float('{:.4f}'.format(conf))
109
  return label, conf
110
 
111
+ def get_results(self, img):
112
+ img = self.preprocess(img)
113
+ pred = self.inference(img)
114
+ return self.postprocess(pred)
 
115
 
116
 
117
  if __name__ == '__main__':
 
125
  '''
126
  # init model
127
  demo = TFliteDemo(os.path.join(os.path.dirname(__file__), 'model.tflite'))
128
+ app = gr.Interface(
129
+ fn=demo.get_results,
130
  inputs="image",
131
  outputs=[
132
  gr.Textbox(label="Plate Number", type="text"),
 
134
  ],
135
  title=_TITLE,
136
  description=_DESCRIPTION,
137
+ examples=get_samples(),
138
  )
139
+ app.launch()