hiwei commited on
Commit
78d1a78
1 Parent(s): 315ae29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -8,12 +8,23 @@ TITLE = "Handwritten Digit Recognition Demo"
8
  DESCRIPTION = "This demo employs a basic CNN architecture inspired by [MIT 6.S191’s Lab2 Part1](https://github.com/aamini/introtodeeplearning/blob/master/lab2/Part1_MNIST.ipynb). "\
9
  "It achieves about 98% accuracy on the MNIST test dataset but may perform poorly, particularly with digits 8 and 9, likely due to suboptimal image preprocessing."
10
 
 
 
 
 
 
 
 
 
 
 
 
11
  model = tf.keras.saving.load_model("tf_model_mnist")
12
 
13
 
14
- def preprocess(image):
15
  """ Normalize Gradio image to MNIST format """
16
- image = image.resize((28, 28), Image.Resampling.BOX)
17
  img_array = np.asarray(image, dtype=np.float32)
18
  for i in range(img_array.shape[0]):
19
  for j in range(img_array.shape[1]):
@@ -30,14 +41,19 @@ def preprocess(image):
30
  return image_array, new_image
31
 
32
 
33
- def predict(img):
34
  img = img["composite"]
35
- input_arr, new_image = preprocess(img)
36
  print("input:", input_arr.shape)
37
  predictions = model.predict(input_arr)
38
  return {str(i): predictions[0][i] for i in range(10)}, new_image
39
 
40
 
 
 
 
 
 
41
  input_image = gr.Sketchpad(
42
  layers=False,
43
  type="pil",
@@ -47,7 +63,7 @@ demo = gr.Interface(
47
  predict,
48
  title=TITLE,
49
  description=DESCRIPTION,
50
- inputs=input_image,
51
  outputs=['label', 'image']
52
  )
53
 
 
8
  DESCRIPTION = "This demo employs a basic CNN architecture inspired by [MIT 6.S191’s Lab2 Part1](https://github.com/aamini/introtodeeplearning/blob/master/lab2/Part1_MNIST.ipynb). "\
9
  "It achieves about 98% accuracy on the MNIST test dataset but may perform poorly, particularly with digits 8 and 9, likely due to suboptimal image preprocessing."
10
 
11
+
12
+ PIL_INTERPOLATION_METHODS = {
13
+ "nearest": Image.Resampling.NEAREST,
14
+ "bilinear": Image.Resampling.BILINEAR,
15
+ "bicubic": Image.Resampling.BICUBIC,
16
+ "hamming": Image.Resampling.HAMMING,
17
+ "box": Image.Resampling.BOX,
18
+ "lanczos": Image.Resampling.LANCZOS,
19
+ }
20
+
21
+
22
  model = tf.keras.saving.load_model("tf_model_mnist")
23
 
24
 
25
+ def preprocess(image, resample_method):
26
  """ Normalize Gradio image to MNIST format """
27
+ image = image.resize((28, 28), PIL_INTERPOLATION_METHODS[resample_method])
28
  img_array = np.asarray(image, dtype=np.float32)
29
  for i in range(img_array.shape[0]):
30
  for j in range(img_array.shape[1]):
 
41
  return image_array, new_image
42
 
43
 
44
+ def predict(img, resample_method):
45
  img = img["composite"]
46
+ input_arr, new_image = preprocess(img, resample_method)
47
  print("input:", input_arr.shape)
48
  predictions = model.predict(input_arr)
49
  return {str(i): predictions[0][i] for i in range(10)}, new_image
50
 
51
 
52
+ resample_method = gr.Dropdown(
53
+ choices=list(PIL_INTERPOLATION_METHODS.keys()),
54
+ value='bilinear',
55
+ )
56
+
57
  input_image = gr.Sketchpad(
58
  layers=False,
59
  type="pil",
 
63
  predict,
64
  title=TITLE,
65
  description=DESCRIPTION,
66
+ inputs=[input_image, resample_method],
67
  outputs=['label', 'image']
68
  )
69