vrk05 committed on
Commit
0349776
·
1 Parent(s): 54fd9ee
Files changed (1)
  1. app.py +217 -11
app.py CHANGED
@@ -1,15 +1,221 @@
  import gradio as gr
- from transformers import pipeline

- pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")

- def predict(image):
-     predictions = pipeline(image)
-     return {p["label"]: p["score"] for p in predictions}

- gr.Interface(
-     predict,
-     inputs=gr.inputs.Image(label="Upload hot dog candidate", type="filepath"),
-     outputs=gr.outputs.Label(num_top_classes=2),
-     title="Hot Dog? Or Not?",
- ).launch()
 
+ import os
+ from huggingface_hub import hf_hub_download
+
+ config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-sen1floods11", filename="sen1floods11_Prithvi_100M.py", token=os.environ.get("token"))
+ ckpt = hf_hub_download(repo_id="Flooding_IBM", filename='floods_fine.pth', token=os.environ.get("token"))
+
+ import argparse
+ from mmcv import Config
+ from mmseg.models import build_segmentor
+ from mmseg.datasets.pipelines import Compose, LoadImageFromFile
+ import rasterio
+ import torch
+ from mmseg.apis import init_segmentor
+ from mmcv.parallel import collate, scatter
+ import numpy as np
+ import glob
+ import time
  import gradio as gr
+ from functools import partial
+ import pdb
+ import matplotlib.pyplot as plt
+ from skimage import exposure
+
+
+ def stretch_rgb(rgb):
+     # contrast-stretch the RGB quicklook to the 1st-99th percentile range
+     ls_pct = 1
+     pLow, pHigh = np.percentile(rgb[~np.isnan(rgb)], (ls_pct, 100 - ls_pct))
+     img_rescale = exposure.rescale_intensity(rgb, in_range=(pLow, pHigh))
+     return img_rescale
+
+
+ def open_tiff(fname):
+     with rasterio.open(fname, "r") as src:
+         data = src.read()
+     return data
+
+
+ def write_tiff(img_wrt, filename, metadata):
+     """
+     Write a raster image to file.
+
+     :param img_wrt: numpy array containing the data (2D for a single band or 3D for multiple bands)
+     :param filename: file path of the output file
+     :param metadata: metadata to use to write the raster to disk
+     :return: the output file path
+     """
+     with rasterio.open(filename, "w", **metadata) as dest:
+         if len(img_wrt.shape) == 2:
+             img_wrt = img_wrt[None]
+         for i in range(img_wrt.shape[0]):
+             dest.write(img_wrt[i, :, :], i + 1)
+     return filename
+
+
+ def get_meta(fname):
+     with rasterio.open(fname, "r") as src:
+         meta = src.meta
+     return meta
+
+
+ def preprocess_example(example_list):
+     example_list = [os.path.join(os.path.abspath(''), x) for x in example_list]
+     return example_list
+
+
+ def inference_segmentor(model, imgs, custom_test_pipeline=None):
+     """Inference image(s) with the segmentor.
+
+     Args:
+         model (nn.Module): The loaded segmentor.
+         imgs (str/ndarray or list[str/ndarray]): Either image files or loaded images.
+
+     Returns:
+         (list[Tensor]): The segmentation result.
+     """
+     cfg = model.cfg
+     device = next(model.parameters()).device  # model device
+     # build the data pipeline
+     test_pipeline = [LoadImageFromFile()] + cfg.data.test.pipeline[1:] if custom_test_pipeline is None else custom_test_pipeline
+     test_pipeline = Compose(test_pipeline)
+     # prepare data
+     data = []
+     imgs = imgs if isinstance(imgs, list) else [imgs]
+     for img in imgs:
+         img_data = {'img_info': {'filename': img}}
+         img_data = test_pipeline(img_data)
+         data.append(img_data)
+
+     data = collate(data, samples_per_gpu=len(imgs))
+     if next(model.parameters()).is_cuda:
+         # scatter to the specified GPU
+         data = scatter(data, [device])[0]
+     else:
+         # on CPU, unwrap the DataContainers manually
+         img_metas = data['img_metas'].data[0]
+         img = data['img']
+         data = {'img': img, 'img_metas': img_metas}
+
+     with torch.no_grad():
+         result = model(return_loss=False, rescale=True, **data)
+     return result
+
+ def inference_on_file(target_image, model, custom_test_pipeline):
+     # Gradio passes a tempfile wrapper; grab the path on disk
+     target_image = target_image.name
+
+     st = time.time()
+     print('Running inference...')
+     result = inference_segmentor(model, target_image, custom_test_pipeline)
+     print("Output has shape: " + str(result[0].shape))
+
+     # prep outputs: build an RGB quicklook and mask out nodata pixels
+     mask = open_tiff(target_image)
+     rgb = stretch_rgb((mask[[3, 2, 1], :, :].transpose((1, 2, 0)) / 10000 * 255).astype(np.uint8))
+     meta = get_meta(target_image)
+     mask = np.where(mask == meta['nodata'], 1, 0)
+     mask = np.max(mask, axis=0)[None]
+     rgb = np.where(mask.transpose((1, 2, 0)) == 1, 0, rgb)
+     rgb = np.where(rgb < 0, 0, rgb)
+     rgb = np.where(rgb > 255, 255, rgb)
+
+     prediction = np.where(mask == 1, 0, result[0] * 255)
+     et = time.time()
+     time_taken = np.round(et - st, 1)
+     print(f'Inference completed in {time_taken} seconds')
+
+     return rgb, prediction[0]
+
+ def process_test_pipeline(custom_test_pipeline, bands=None):
+     # change extracted bands if necessary
+     if bands is not None:
+         extract_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'] == 'BandsExtract']
+         if len(extract_index) > 0:
+             custom_test_pipeline[extract_index[0]]['bands'] = eval(bands)
+
+     collect_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'].find('Collect') > -1]
+
+     # adapt collected keys if necessary
+     if len(collect_index) > 0:
+         keys = ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']
+         custom_test_pipeline[collect_index[0]]['meta_keys'] = keys
+
+     return custom_test_pipeline
+
+ config = Config.fromfile(config_path)
+ config.model.backbone.pretrained = None
+ model = init_segmentor(config, ckpt, device='cpu')
+ custom_test_pipeline = process_test_pipeline(model.cfg.data.test.pipeline, None)
+
+ func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
+
+ with gr.Blocks() as demo:
+
+     gr.Markdown(value='# Prithvi sen1floods11')
+     gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision Transformer pretrained by the IBM and NASA team on continental US Harmonized Landsat Sentinel-2 (HLS) data. This demo showcases how the model was fine-tuned to detect water at a higher resolution than it was trained on (10 m versus 30 m) using Sentinel-2 imagery from the [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11). More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11).
+
+ The user needs to provide a Sentinel-2 image with all 12 bands in the usual Sentinel-2 order, in reflectance units multiplied by 10,000 (to save on space); the code will pull out the Blue, Green, Red, Narrow NIR, SWIR, and SWIR 2 bands.
+ ''')
+     with gr.Row():
+         with gr.Column():
+             inp = gr.File()
+             btn = gr.Button("Submit")
+
+     with gr.Row():
+         gr.Markdown(value='### Input RGB')
+         gr.Markdown(value='### Model prediction (Black: Land; White: Water)')
+
+     with gr.Row():
+         out1 = gr.Image(image_mode='RGB')
+         out2 = gr.Image(image_mode='L')
+
+     btn.click(fn=func, inputs=inp, outputs=[out1, out2])
+
+     with gr.Row():
+         gr.Examples(examples=["India_900498_S2Hand.tif",
+                               "Spain_7370579_S2Hand.tif",
+                               "USA_430764_S2Hand.tif"],
+                     inputs=inp,
+                     outputs=[out1, out2],
+                     preprocess=preprocess_example,
+                     fn=func,
+                     cache_examples=True,
+                     )
+
+ demo.launch()
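
Note on input preparation: the demo description in this commit requires a single GeoTIFF with all 12 Sentinel-2 bands in the usual band order, in reflectance units scaled by 10,000, and `inference_on_file` reads the RGB quicklook from band indices [3, 2, 1] and the nodata value from the file's metadata. Below is a minimal sketch of how such a file could be assembled with rasterio; it is not part of the commit, and the per-band file names, the L2A band list, and the -9999 nodata value are illustrative assumptions.

import numpy as np
import rasterio

# hypothetical per-band GeoTIFFs, one per Sentinel-2 L2A band (assumption)
band_ids = ["01", "02", "03", "04", "05", "06", "07", "08", "8A", "09", "11", "12"]
band_files = [f"S2_tile_B{b}.tif" for b in band_ids]

with rasterio.open(band_files[0]) as src:
    meta = src.meta.copy()  # reuse georeferencing/CRS from the first band

bands = []
for path in band_files:
    with rasterio.open(path) as src:
        bands.append(src.read(1))  # band 1 of each single-band file

stack = np.stack(bands)  # shape (12, height, width); the app reads RGB from indices [3, 2, 1]
# the app expects reflectance * 10,000 stored as integers; -9999 nodata is an assumption
meta.update(count=stack.shape[0], dtype="int16", nodata=-9999)

with rasterio.open("S2_stacked.tif", "w", **meta) as dst:
    dst.write(stack.astype("int16"))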