yuragoithf commited on
Commit
834fb23
1 Parent(s): f6f9242

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -26
app.py CHANGED
@@ -33,43 +33,42 @@ model_file = download_model()
33
  model = tf.keras.models.load_model(model_file)
34
 
35
  # Perform image classification for single class output
36
- def predict_class(image):
37
- img = tf.cast(image, tf.float32)
38
- img = tf.image.resize(img, [input_shape[0], input_shape[1]])
39
- img = tf.expand_dims(img, axis=0)
40
- prediction = model.predict(img)
41
- class_index = tf.argmax(prediction[0]).numpy()
42
- predicted_class = labels[class_index]
43
- print("predicted_class is ", predicted_class)####################################################
44
- return predicted_class
45
-
46
- # Perform image classification for multy class output
47
  # def predict_class(image):
48
  # img = tf.cast(image, tf.float32)
49
  # img = tf.image.resize(img, [input_shape[0], input_shape[1]])
50
  # img = tf.expand_dims(img, axis=0)
51
  # prediction = model.predict(img)
52
- # return prediction[0]
 
 
 
 
 
 
 
 
 
 
53
 
54
  # UI Design for single class output
55
- def classify_image(image):
56
- predicted_class = predict_class(image)
57
- output = f"<h2>Predicted Class: <span style='text-transform:uppercase';>{predicted_class}</span></h2>"
58
- return output
59
 
60
 
61
  # UI Design for multy class output
62
- # def classify_image(image):
63
- # results = predict_class(image)
64
- # print(results)
65
- # output = {labels.get(i): float(results[i]) for i in range(len(results))}
66
- # result = output if max(output.values()) >=0.98 else {"NO_CIFAR10_CLASS": 1}
67
- # return result
68
 
69
 
70
  inputs = gr.inputs.Image(type="pil", label="Upload an image")
71
- outputs = gr.outputs.HTML() #uncomment for single class output
72
- #outputs = gr.outputs.Label(num_top_classes=4)
73
 
74
  title = "<h1 style='text-align: center;'>Image Classifier</h1>"
75
  description = "Upload an image and get the predicted class."
@@ -81,5 +80,4 @@ gr.Interface(fn=classify_image,
81
  title=title,
82
  examples=[["00_plane.jpg"], ["01_car.jpg"], ["02_house.jpg"], ["03_cat.jpg"], ["04_deer.jpg"]],
83
  # css=css_code,
84
- description=description,
85
- enable_queue=True).launch()
 
33
  model = tf.keras.models.load_model(model_file)
34
 
35
  # Perform image classification for single class output
 
 
 
 
 
 
 
 
 
 
 
36
  # def predict_class(image):
37
  # img = tf.cast(image, tf.float32)
38
  # img = tf.image.resize(img, [input_shape[0], input_shape[1]])
39
  # img = tf.expand_dims(img, axis=0)
40
  # prediction = model.predict(img)
41
+ # class_index = tf.argmax(prediction[0]).numpy()
42
+ # predicted_class = labels[class_index]
43
+ # return predicted_class
44
+
45
+ # Perform image classification for multy class output
46
+ def predict_class(image):
47
+ img = tf.cast(image, tf.float32)
48
+ img = tf.image.resize(img, [input_shape[0], input_shape[1]])
49
+ img = tf.expand_dims(img, axis=0)
50
+ prediction = model.predict(img)
51
+ return prediction[0]
52
 
53
  # UI Design for single class output
54
+ # def classify_image(image):
55
+ # predicted_class = predict_class(image)
56
+ # output = f"<h2>Predicted Class: <span style='text-transform:uppercase';>{predicted_class}</span></h2>"
57
+ # return output
58
 
59
 
60
  # UI Design for multy class output
61
+ def classify_image(image):
62
+ results = predict_class(image)
63
+ print(results)
64
+ output = {labels.get(i): float(results[i]) for i in range(len(results))}
65
+ result = output if max(output.values()) >=0.98 else {"NO_CIFAR10_CLASS": 1}
66
+ return result
67
 
68
 
69
  inputs = gr.inputs.Image(type="pil", label="Upload an image")
70
+ # outputs = gr.outputs.HTML() #uncomment for single class output
71
+ outputs = gr.outputs.Label(num_top_classes=4)
72
 
73
  title = "<h1 style='text-align: center;'>Image Classifier</h1>"
74
  description = "Upload an image and get the predicted class."
 
80
  title=title,
81
  examples=[["00_plane.jpg"], ["01_car.jpg"], ["02_house.jpg"], ["03_cat.jpg"], ["04_deer.jpg"]],
82
  # css=css_code,
83
+ description=description).launch()