amosfang commited on
Commit
81c8e53
1 Parent(s): 20ffaff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from PIL import Image
4
+ from skimage.transform import resize
5
+ import tensorflow as tf
6
+ from tensorflow.keras.models import load_model
7
+
8
+ import gradio as gr
9
+ import os
10
+
11
+ REPO_ID = "amosfang/segmentation_u_net"
12
+
13
+ def pil_image_as_numpy_array(pilimg):
14
+ img_array = tf.keras.utils.img_to_array(pilimg)
15
+ return img_array
16
+
17
+ def resize_image(image, input_shape=(224, 224, 3)):
18
+ # Convert to NumPy array and normalize
19
+ image_array = pil_image_as_numpy_array(image)
20
+ image = image_array.astype(np.float32) / 255.
21
+
22
+ # Resize the image to 224x224
23
+ image_resized = resize(image, input_shape, anti_aliasing=True)
24
+
25
+ return image_resized
26
+
27
+ def load_model():
28
+ model_dir = snapshot_download(REPO_ID)
29
+ # saved_model_dir = os.path.join(download_dir, "saved_model")
30
+ unet_model = load_model(model_dir)
31
+ return unet_model
32
+
33
+ def ensemble_predict(X_array):
34
+ #
35
+ # Call the predict methods of the unet_model and the vgg16_unet_model
36
+ # to retrieve their predictions.
37
+ #
38
+ # Sum the two predictions together and return their results.
39
+ # You can also consider multiplying a different weight on
40
+ # one or both of the models to improve performance
41
+
42
+ X_array = np.expand_dims(X_array, axis=0)
43
+
44
+ unet_model = load_model('REPO_ID/train_2024-02-14 11-20-17/base_u_net.0098-acc-0.75-val_acc-0.74-loss-0.79.h5')
45
+ vgg16_model = load_model('REPO_ID/vgg16_u_net.0092-acc-0.74-val_acc-0.74-loss-0.82.h5')
46
+ resnet50_model = load_model('REPO_ID/resnet50_u_net.0095-acc-0.79-val_acc-0.76-loss-0.72.h5')
47
+
48
+ pred_y_unet = unet_model.predict(X_array)
49
+ pred_y_vgg16 = vgg16_model.predict(X_array)
50
+ pred_y_resnet50 = resnet50_model.predict(X_array)
51
+
52
+ return (pred_y_unet + pred_y_vgg16 + pred_y_resnet50) / 3
53
+
54
+ def get_predictions(y_prediction_encoded):
55
+
56
+ # Convert predictions to categorical indices
57
+ predicted_label_indices = np.argmax(y_prediction_encoded, axis=-1) + 1
58
+
59
+ return predicted_label_indices
60
+
61
+ def predict(image):
62
+ sample_image_resized = resize_image(image, input_shape)
63
+ y_pred = ensemble_predict(sample_image_resized)
64
+ y_pred = get_predictions(y_pred).squeeze()
65
+
66
+ # Create a figure without saving it to a file
67
+ fig, ax = plt.subplots()
68
+ cax = ax.imshow(y_pred, cmap='viridis', vmin=1, vmax=7)
69
+
70
+ # Convert the figure to a PIL Image
71
+ image_buffer = io.BytesIO()
72
+ plt.savefig(image_buffer, format='png')
73
+ image_buffer.seek(0)
74
+ image_pil = Image.open(image_buffer)
75
+
76
+ # Close the figure to release resources
77
+ plt.close(fig)
78
+
79
+ return image_pil
80
+
81
+ # Specify paths to example images
82
+ sample_images = [['989953_sat.jpg'], ['999380_sat.jpg'], ['988205_sat.jpg']]
83
+
84
+ # Launch Gradio Interface
85
+ gr.Interface(
86
+ predict,
87
+ title='Land Cover Segmentation',
88
+ inputs=[gr.Image()],
89
+ outputs=[gr.Image()],
90
+ examples=sample_images
91
+ ).launch(debug=True, share=True)
92
+
93
+ # Launch the interface
94
+ iface.launch(share=True)