mjms committed on
Commit
181162f
1 Parent(s): ec2a9c9

Create app.py

Files changed (1)
  app.py  +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
# This was made by following this tutorial
# https://www.youtube.com/watch?v=i40ulpcacFM

# Dependencies: segmentation-models, patchify, gradio.
# On a Space these belong in requirements.txt; locally:
#   pip install -U -q segmentation-models patchify gradio
# (the original notebook used `!pip install ...` shell magics, which are not valid
# syntax in a plain app.py)

# # Notebook-only workaround for an efficientnet import-name mismatch:
# # Open the file in read mode
# with open('/usr/local/lib/python3.9/dist-packages/efficientnet/keras.py', 'r') as f:
#     # Read the contents of the file
#     contents = f.read()
#
# # Replace the string
# new_contents = contents.replace('init_keras_custom_objects', 'init_tfkeras_custom_objects')
#
# # Open the file in write mode again and write the modified contents
# with open('/usr/local/lib/python3.9/dist-packages/efficientnet/keras.py', 'w') as f:
#     f.write(new_contents)

import os
from os.path import join as pjoin
import cv2
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from PIL import Image
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from patchify import patchify, unpatchify

from keras import backend as K
from keras.models import load_model

import segmentation_models as sm

import gradio as gr

def jaccard_coef(y_true, y_pred):
    # Smoothed Jaccard (IoU) coefficient; the +1.0 terms avoid division by zero.
    # Defined here so load_model can restore the custom metric the model was saved with.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)

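# Optional sanity check for jaccard_coef, kept commented out so it does not run at
# app start-up; the toy masks below are illustrative, not from the original project.
# Identical masks should score ~1.0, disjoint masks only the smoothing floor:
# import tensorflow as tf
# _y = tf.constant([[0., 1., 1., 0.]])
# print(float(jaccard_coef(_y, _y)))                               # -> 1.0
# print(float(jaccard_coef(_y, tf.constant([[1., 0., 0., 1.]]))))  # -> 0.2 (= 1/5, smoothing only)
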
# Equal weights for the 6 segmentation classes; the combined Dice + focal loss is
# recreated here only so that load_model can resolve the custom loss the model was trained with.
weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

model_path = 'models/satellite_segmentation_100-epochs.h5'
saved_model = load_model(model_path,
                         custom_objects={'dice_loss_plus_1focal_loss': total_loss,
                                         'jaccard_coef': jaccard_coef})

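# Optional check that the exported model really expects 256x256 RGB patches,
# an assumption the patching code below relies on (uncomment to verify):
# print(saved_model.input_shape)   # expected: (None, 256, 256, 3)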

def process_input_image(test_image):
    test_dataset = []
    image_patch_size = 256
    scaler = MinMaxScaler()

    # crop the image so that its height and width are divisible by image_patch_size
    test_image = np.array(test_image)
    size_x = (test_image.shape[1] // image_patch_size) * image_patch_size
    size_y = (test_image.shape[0] // image_patch_size) * image_patch_size

    test_image = Image.fromarray(test_image)
    test_image = test_image.crop((0, 0, size_x, size_y))

    # patchify the image so that each patch has size (image_patch_size, image_patch_size)
    test_image = np.array(test_image)
    image_patches = patchify(test_image, (image_patch_size, image_patch_size, 3),
                             step=image_patch_size)  # 3 should probably be a variable, since inputs may have more channels than RGB

    # scale values so that they are between 0 and 1,
    # using MinMaxScaler from sklearn (fitted per patch)
    for i in range(image_patches.shape[0]):
        for j in range(image_patches.shape[1]):
            image_patch = image_patches[i, j, :, :]

            image_patch = scaler.fit_transform(image_patch.reshape(-1, image_patch.shape[-1])).reshape(image_patch.shape)

            image_patch = image_patch[0]  # drop the extra unnecessary dimension that patchify adds
            test_dataset.append(image_patch)

    test_dataset = [np.expand_dims(np.array(x), 0) for x in test_dataset]

    # predict each patch and take the argmax over the class dimension
    test_prediction = []

    for image in tqdm(test_dataset):
        prediction = saved_model.predict(image, verbose=0)
        predicted_image = np.argmax(prediction, axis=3)
        predicted_image = predicted_image[0, :, :]
        test_prediction.append(predicted_image)

    # stitch the per-patch predictions back into a full-size mask
    reconstructed_image = np.reshape(np.array(test_prediction),
                                     (image_patches.shape[0], image_patches.shape[1], image_patch_size, image_patch_size))
    reconstructed_image = unpatchify(reconstructed_image, (size_y, size_x))

    # map each integer class label to its RGB colour
    lookup = {'rgb': [np.array([ 60,  16, 152]),
                      np.array([132,  41, 246]),
                      np.array([110, 193, 228]),
                      np.array([254, 221,  58]),
                      np.array([226, 169,  41]),
                      np.array([155, 155, 155])],
              'int': [0, 1, 2, 3, 4, 5]}

    rgb_image = np.zeros((reconstructed_image.shape[0], reconstructed_image.shape[1], 3), dtype=np.uint8)

    for i, l in enumerate(lookup['int']):
        rgb_image[np.where(reconstructed_image == l)] = lookup['rgb'][i]
    return 'Predicted Masked Image', rgb_image

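# Optional: the pipeline above can also be exercised without the Gradio UI, e.g. on a
# local file (the file name below is hypothetical):
# sample = Image.open('sample_satellite_image.png').convert('RGB')
# label_text, mask_rgb = process_input_image(sample)
# plt.imshow(mask_rgb); plt.title(label_text); plt.show()
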
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    img_source = gr.Image(label="Please select source Image")
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")
            source_image_loader.click(
                process_input_image,
                [
                    img_source
                ],
                [
                    output_label,
                    img_output
                ]
            )

my_app.launch(debug=True)
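# When run locally (`python app.py`), Gradio serves the UI on http://127.0.0.1:7860 by
# default; debug=True prints errors and stack traces to the console while the app runs.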