Update app.py
Browse files
app.py
CHANGED
@@ -14,6 +14,9 @@ import os
|
|
14 |
import io
|
15 |
|
16 |
REPO_ID = "amosfang/segmentation_u_net"
|
|
|
|
|
|
|
17 |
|
18 |
def pil_image_as_numpy_array(pilimg):
|
19 |
img_array = tf.keras.utils.img_to_array(pilimg)
|
@@ -39,13 +42,26 @@ def get_sample_images(image_folder, format=('.jpg', '.jpeg')):
|
|
39 |
|
40 |
# Get a list of all files in the folder
|
41 |
img_file_list = os.listdir(image_folder)
|
42 |
-
img_file_list.sort()
|
43 |
|
44 |
-
# Filter out only the image files (assuming images have extensions like '.jpg')
|
45 |
image_files = [[image_folder +'/' + file] for file in img_file_list if file.lower().endswith(format)]
|
46 |
|
47 |
return image_files
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def ensemble_predict(X_array):
|
50 |
#
|
51 |
# Call the predict methods of the unet_model and the vgg16_unet_model
|
@@ -75,8 +91,12 @@ def get_predictions(y_prediction_encoded):
|
|
75 |
return predicted_label_indices
|
76 |
|
77 |
def predict_on_train(image):
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
# Steps to get prediction
|
80 |
sample_image_resized = resize_image(image)
|
81 |
y_pred = ensemble_predict(sample_image_resized)
|
82 |
y_pred = get_predictions(y_pred).squeeze()
|
@@ -93,10 +113,10 @@ def predict_on_train(image):
|
|
93 |
ax.imshow(sample_image_resized)
|
94 |
|
95 |
# Display the predictions using the specified colormap
|
96 |
-
cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=7, alpha=0.5)
|
97 |
|
98 |
# Create colorbar and set ticks and ticklabels
|
99 |
-
cbar = plt.colorbar(cax, ticks=np.arange(1, 8))
|
100 |
cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest', 'Water', 'Barren', 'Unknown'])
|
101 |
|
102 |
# Convert the figure to a PIL Image
|
@@ -108,7 +128,7 @@ def predict_on_train(image):
|
|
108 |
# Close the figure to release resources
|
109 |
plt.close(fig)
|
110 |
|
111 |
-
return
|
112 |
|
113 |
def predict_on_test(image):
|
114 |
|
@@ -147,8 +167,8 @@ def predict_on_test(image):
|
|
147 |
return image_pil
|
148 |
|
149 |
|
150 |
-
|
151 |
-
|
152 |
|
153 |
description= '''
|
154 |
The DeepGlobe Land Cover Classification Challenge offers the first public dataset containing high resolution
|
|
|
14 |
import io
|
15 |
|
16 |
REPO_ID = "amosfang/segmentation_u_net"
|
17 |
+
TRAIN_FOLDER = 'train_images'
|
18 |
+
TEST_FOLDER = 'example_images'
|
19 |
+
NUM_CLASSES = 7
|
20 |
|
21 |
def pil_image_as_numpy_array(pilimg):
|
22 |
img_array = tf.keras.utils.img_to_array(pilimg)
|
|
|
42 |
|
43 |
# Get a list of all files in the folder
|
44 |
img_file_list = os.listdir(image_folder)
|
|
|
45 |
|
46 |
+
# Filter out only the image files (assuming images have extensions like '.jpg')
|
47 |
image_files = [[image_folder +'/' + file] for file in img_file_list if file.lower().endswith(format)]
|
48 |
|
49 |
return image_files
|
50 |
|
51 |
+
def get_sample_mask(image_folder, image):
|
52 |
+
|
53 |
+
# Get the filename of the original image
|
54 |
+
image_filename = os.path.basename(image.name)
|
55 |
+
|
56 |
+
# Construct the filename for the ground truth mask
|
57 |
+
mask_filename = image_filename.replace('_sat.jpg', '_mask.png')
|
58 |
+
|
59 |
+
# Load the ground truth mask
|
60 |
+
mask_path = os.path.join(image_folder, mask_filename)
|
61 |
+
ground_truth_mask = Image.open(mask_path)
|
62 |
+
|
63 |
+
return ground_truth_mask
|
64 |
+
|
65 |
def ensemble_predict(X_array):
|
66 |
#
|
67 |
# Call the predict methods of the unet_model and the vgg16_unet_model
|
|
|
91 |
return predicted_label_indices
|
92 |
|
93 |
def predict_on_train(image):
|
94 |
+
|
95 |
+
# Steps to get the ground truth image mask
|
96 |
+
ground_truth_mask = get_sample_mask(TRAIN_FOLDER, image)
|
97 |
+
ground_truth_mask_pil = resize_image(ground_truth_mask)
|
98 |
|
99 |
+
# Steps to get prediction of the satellite image
|
100 |
sample_image_resized = resize_image(image)
|
101 |
y_pred = ensemble_predict(sample_image_resized)
|
102 |
y_pred = get_predictions(y_pred).squeeze()
|
|
|
113 |
ax.imshow(sample_image_resized)
|
114 |
|
115 |
# Display the predictions using the specified colormap
|
116 |
+
cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=NUM_CLASSES, alpha=0.5)
|
117 |
|
118 |
# Create colorbar and set ticks and ticklabels
|
119 |
+
cbar = plt.colorbar(cax, ticks=np.arange(1, NUM_CLASSES + 1))
|
120 |
cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest', 'Water', 'Barren', 'Unknown'])
|
121 |
|
122 |
# Convert the figure to a PIL Image
|
|
|
128 |
# Close the figure to release resources
|
129 |
plt.close(fig)
|
130 |
|
131 |
+
return ground_truth_mask_pil, image_pil
|
132 |
|
133 |
def predict_on_test(image):
|
134 |
|
|
|
167 |
return image_pil
|
168 |
|
169 |
|
170 |
+
train_images = get_sample_images(TRAIN_FOLDER)
|
171 |
+
sample_images = get_sample_images(TEST_FOLDER)
|
172 |
|
173 |
description= '''
|
174 |
The DeepGlobe Land Cover Classification Challenge offers the first public dataset containing high resolution
|