aswinjosephe committed
Commit 52468b4
1 Parent(s): 814a37c

Update app.py

Files changed (1):
  app.py +219 -4
app.py CHANGED
@@ -1,7 +1,222 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch(share=True)
+import os
+from huggingface_hub import hf_hub_download
+
+config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-sen1floods11", filename="sen1floods11_Prithvi_100M.py", token=os.environ.get("token"))
+ckpt = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-sen1floods11", filename="sen1floods11_Prithvi_100M.pth", token=os.environ.get("token"))
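+# NOTE: "token" must be set in the environment (e.g. as a Space secret); it is
+# presumably an access token with read permission for this repo, and the
+# downloads above may fail without it if the repo requires authentication.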
+##########
+
+import argparse
+import glob
+import time
+import pdb
+
+import numpy as np
+import rasterio
+import torch
+import matplotlib.pyplot as plt
+from functools import partial
+from skimage import exposure
 import gradio as gr
+
+from mmcv import Config
+from mmcv.parallel import collate, scatter
+from mmseg.models import build_segmentor
+from mmseg.apis import init_segmentor
+from mmseg.datasets.pipelines import Compose, LoadImageFromFile
+
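+# Percentile-based contrast stretch for the RGB quicklook: clip the darkest and
+# brightest 1% of pixels, then rescale intensities into the display range.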
+def stretch_rgb(rgb):
+    ls_pct = 1
+    pLow, pHigh = np.percentile(rgb[~np.isnan(rgb)], (ls_pct, 100 - ls_pct))
+    img_rescale = exposure.rescale_intensity(rgb, in_range=(pLow, pHigh))
+    return img_rescale
+
+
+def open_tiff(fname):
+    with rasterio.open(fname, "r") as src:
+        data = src.read()
+    return data
+
+
+def write_tiff(img_wrt, filename, metadata):
+    """Write a raster image to file.
+
+    :param img_wrt: numpy array containing the data (2D for a single band or 3D for multiple bands)
+    :param filename: file path of the output file
+    :param metadata: rasterio metadata used to write the raster to disk
+    :return: the output file path
+    """
+    with rasterio.open(filename, "w", **metadata) as dest:
+        if len(img_wrt.shape) == 2:
+            img_wrt = img_wrt[None]
+        for i in range(img_wrt.shape[0]):
+            dest.write(img_wrt[i, :, :], i + 1)
+    return filename
+
+
+def get_meta(fname):
+    with rasterio.open(fname, "r") as src:
+        meta = src.meta
+    return meta
+
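+# gr.Examples passes bare filenames; resolving them to absolute paths lets the
+# cached example files be located at runtime.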
+def preprocess_example(example_list):
+    example_list = [os.path.join(os.path.abspath(''), x) for x in example_list]
+    return example_list
+
+
+def inference_segmentor(model, imgs, custom_test_pipeline=None):
+    """Run inference on image(s) with the segmentor.
+
+    Args:
+        model (nn.Module): The loaded segmentor.
+        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
+            images.
+
+    Returns:
+        (list[Tensor]): The segmentation result.
+    """
+    cfg = model.cfg
+    device = next(model.parameters()).device  # model device
+    # build the data pipeline
+    test_pipeline = ([LoadImageFromFile()] + cfg.data.test.pipeline[1:]) if custom_test_pipeline is None else custom_test_pipeline
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = []
+    imgs = imgs if isinstance(imgs, list) else [imgs]
+    for img in imgs:
+        img_data = {'img_info': {'filename': img}}
+        img_data = test_pipeline(img_data)
+        data.append(img_data)
+
+    data = collate(data, samples_per_gpu=len(imgs))
+    if next(model.parameters()).is_cuda:
+        # scatter to specified GPU
+        data = scatter(data, [device])[0]
+    else:
+        # on CPU, unwrap the DataContainers produced by collate
+        img_metas = data['img_metas'].data[0]
+        img = data['img']
+        data = {'img': img, 'img_metas': img_metas}
+
+    with torch.no_grad():
+        result = model(return_loss=False, rescale=True, **data)
+    return result
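+
+# Hypothetical direct use (the path is illustrative only):
+#   result = inference_segmentor(model, "tile_S2Hand.tif", custom_test_pipeline)
+#   result[0] is an (H, W) array with 0 = land and 1 = water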
+
+
+def inference_on_file(target_image, model, custom_test_pipeline):
+    target_image = target_image.name
+
+    st = time.time()
+    print('Running inference...')
+    result = inference_segmentor(model, target_image, custom_test_pipeline)
+    print("Output has shape: " + str(result[0].shape))
+
+    ##### prep outputs
+    mask = open_tiff(target_image)
+    # bands B04, B03, B02 (R, G, B), rescaled from reflectance * 10,000 to 8-bit
+    rgb = stretch_rgb((mask[[3, 2, 1], :, :].transpose((1, 2, 0)) / 10000 * 255).astype(np.uint8))
+    meta = get_meta(target_image)
+    # hide nodata pixels in both the RGB quicklook and the prediction
+    mask = np.where(mask == meta['nodata'], 1, 0)
+    mask = np.max(mask, axis=0)[None]
+    rgb = np.where(mask.transpose((1, 2, 0)) == 1, 0, rgb)
+    rgb = np.clip(rgb, 0, 255)
+
+    prediction = np.where(mask == 1, 0, result[0] * 255)
+    et = time.time()
+    time_taken = np.round(et - st, 1)
+    print(f'Inference completed in {time_taken} seconds')
+
+    return rgb, prediction[0]
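+
+# The (rgb, prediction) pair returned above feeds the two gr.Image outputs
+# wired up via btn.click below.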
+
+
+def process_test_pipeline(custom_test_pipeline, bands=None):
+    # change extracted bands if necessary
+    if bands is not None:
+        extract_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'] == 'BandsExtract']
+        if len(extract_index) > 0:
+            # bands arrives as a string such as "[1, 2, 3]", hence the eval
+            custom_test_pipeline[extract_index[0]]['bands'] = eval(bands)
+
+    collect_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'].find('Collect') > -1]
+
+    # adapt collected keys if necessary
+    if len(collect_index) > 0:
+        keys = ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']
+        custom_test_pipeline[collect_index[0]]['meta_keys'] = keys
+
+    return custom_test_pipeline
+
+
+config = Config.fromfile(config_path)
+config.model.backbone.pretrained = None
+model = init_segmentor(config, ckpt, device='cpu')
+custom_test_pipeline = process_test_pipeline(model.cfg.data.test.pipeline, None)
+
+func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
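+# Inference runs on CPU here; init_segmentor(config, ckpt, device='cuda:0')
+# could be used instead if a GPU is available.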
 
+with gr.Blocks() as demo:
+
+    gr.Markdown(value='# Prithvi sen1floods11')
+    gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision Transformer pretrained by the IBM and NASA team on continental US Harmonized Landsat Sentinel-2 (HLS) data. This demo showcases how the model was finetuned to detect water at a higher resolution than it was trained on (i.e. 10 m versus 30 m) using Sentinel-2 imagery from the [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11). More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11).
+
+The user needs to provide a Sentinel-2 image with all 12 bands (in the usual Sentinel-2 order) in reflectance units multiplied by 10,000 (e.g. to save space); the code will extract the Blue, Green, Red, Narrow NIR, SWIR, and SWIR 2 bands.
+''')
+    with gr.Row():
+        with gr.Column():
+            inp = gr.File()
+            btn = gr.Button("Submit")
+
+    with gr.Row():
+        gr.Markdown(value='### Input RGB')
+        gr.Markdown(value='### Model prediction (Black: Land; White: Water)')
+
+    with gr.Row():
+        out1 = gr.Image(image_mode='RGB')
+        out2 = gr.Image(image_mode='L')
+
+    btn.click(fn=func, inputs=inp, outputs=[out1, out2])
+
+    with gr.Row():
+        gr.Examples(examples=["India_900498_S2Hand.tif",
+                              "Spain_7370579_S2Hand.tif",
+                              "USA_430764_S2Hand.tif"],
+                    inputs=inp,
+                    outputs=[out1, out2],
+                    preprocess=preprocess_example,
+                    fn=func,
+                    cache_examples=True,
+                    )
 
+demo.launch()
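+# (When running locally, demo.launch(share=True) would also create a public
+# link, as the previous version of this app did.)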