amosfang commited on
Commit
a0d7f62
1 Parent(s): a2eb125

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -14,6 +14,9 @@ import os
14
  import io
15
 
16
  REPO_ID = "amosfang/segmentation_u_net"
 
 
 
17
 
18
  def pil_image_as_numpy_array(pilimg):
19
  img_array = tf.keras.utils.img_to_array(pilimg)
@@ -39,13 +42,26 @@ def get_sample_images(image_folder, format=('.jpg', '.jpeg')):
39
 
40
  # Get a list of all files in the folder
41
  img_file_list = os.listdir(image_folder)
42
- img_file_list.sort()
43
 
44
- # Filter out only the image files (assuming images have extensions like '.jpg' or '.png')
45
  image_files = [[image_folder +'/' + file] for file in img_file_list if file.lower().endswith(format)]
46
 
47
  return image_files
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def ensemble_predict(X_array):
50
  #
51
  # Call the predict methods of the unet_model and the vgg16_unet_model
@@ -75,8 +91,12 @@ def get_predictions(y_prediction_encoded):
75
  return predicted_label_indices
76
 
77
  def predict_on_train(image):
 
 
 
 
78
 
79
- # Steps to get prediction
80
  sample_image_resized = resize_image(image)
81
  y_pred = ensemble_predict(sample_image_resized)
82
  y_pred = get_predictions(y_pred).squeeze()
@@ -93,10 +113,10 @@ def predict_on_train(image):
93
  ax.imshow(sample_image_resized)
94
 
95
  # Display the predictions using the specified colormap
96
- cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=7, alpha=0.5)
97
 
98
  # Create colorbar and set ticks and ticklabels
99
- cbar = plt.colorbar(cax, ticks=np.arange(1, 8))
100
  cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest', 'Water', 'Barren', 'Unknown'])
101
 
102
  # Convert the figure to a PIL Image
@@ -108,7 +128,7 @@ def predict_on_train(image):
108
  # Close the figure to release resources
109
  plt.close(fig)
110
 
111
- return image_pil, image_pil
112
 
113
  def predict_on_test(image):
114
 
@@ -147,8 +167,8 @@ def predict_on_test(image):
147
  return image_pil
148
 
149
 
150
- sample_images = get_sample_images('example_images')
151
- train_images = get_sample_images('train_images')
152
 
153
  description= '''
154
  The DeepGlobe Land Cover Classification Challenge offers the first public dataset containing high resolution
 
14
  import io
15
 
16
  REPO_ID = "amosfang/segmentation_u_net"
17
+ TRAIN_FOLDER = 'train_images'
18
+ TEST_FOLDER = 'example_images'
19
+ NUM_CLASSES = 7
20
 
21
  def pil_image_as_numpy_array(pilimg):
22
  img_array = tf.keras.utils.img_to_array(pilimg)
 
42
 
43
  # Get a list of all files in the folder
44
  img_file_list = os.listdir(image_folder)
 
45
 
46
+ # Filter out only the image files (assuming images have extensions like '.jpg')
47
  image_files = [[image_folder +'/' + file] for file in img_file_list if file.lower().endswith(format)]
48
 
49
  return image_files
50
 
51
+ def get_sample_mask(image_folder, image):
52
+
53
+ # Get the filename of the original image
54
+ image_filename = os.path.basename(image.name)
55
+
56
+ # Construct the filename for the ground truth mask
57
+ mask_filename = image_filename.replace('_sat.jpg', '_mask.png')
58
+
59
+ # Load the ground truth mask
60
+ mask_path = os.path.join(image_folder, mask_filename)
61
+ ground_truth_mask = Image.open(mask_path)
62
+
63
+ return ground_truth_mask
64
+
65
  def ensemble_predict(X_array):
66
  #
67
  # Call the predict methods of the unet_model and the vgg16_unet_model
 
91
  return predicted_label_indices
92
 
93
  def predict_on_train(image):
94
+
95
+ # Steps to get the ground truth image mask
96
+ ground_truth_mask = get_sample_mask(TRAIN_FOLDER, image)
97
+ ground_truth_mask_pil = resize_image(ground_truth_mask)
98
 
99
+ # Steps to get prediction of the satellite image
100
  sample_image_resized = resize_image(image)
101
  y_pred = ensemble_predict(sample_image_resized)
102
  y_pred = get_predictions(y_pred).squeeze()
 
113
  ax.imshow(sample_image_resized)
114
 
115
  # Display the predictions using the specified colormap
116
+ cax = ax.imshow(y_pred, cmap=cmap, vmin=1, vmax=NUM_CLASSES, alpha=0.5)
117
 
118
  # Create colorbar and set ticks and ticklabels
119
+ cbar = plt.colorbar(cax, ticks=np.arange(1, NUM_CLASSES + 1))
120
  cbar.set_ticklabels(['Urban', 'Agriculture', 'Range Land', 'Forest', 'Water', 'Barren', 'Unknown'])
121
 
122
  # Convert the figure to a PIL Image
 
128
  # Close the figure to release resources
129
  plt.close(fig)
130
 
131
+ return ground_truth_mask_pil, image_pil
132
 
133
  def predict_on_test(image):
134
 
 
167
  return image_pil
168
 
169
 
170
+ train_images = get_sample_images(TRAIN_FOLDER)
171
+ sample_images = get_sample_images(TEST_FOLDER)
172
 
173
  description= '''
174
  The DeepGlobe Land Cover Classification Challenge offers the first public dataset containing high resolution