Paolo-Fraccaro commited on
Commit
5bac81d
·
1 Parent(s): dab7700

update app

Browse files
Files changed (2) hide show
  1. .gitattributes +1 -0
  2. app.py +16 -59
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.tif* filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -5,11 +5,7 @@ config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temp
5
  filename="multi_temporal_crop_classification_Prithvi_100M.py",
6
  token=os.environ.get("token"))
7
  ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
8
- <<<<<<< HEAD
9
  filename='multi_temporal_crop_classification_Prithvi_100M.pth',
10
- =======
11
- filename='multi_temporal_crop_classification_best_mIoU_epoch_66.pth',
12
- >>>>>>> 889a651 (add files)
13
  token=os.environ.get("token"))
14
  ##########
15
  import argparse
@@ -40,6 +36,15 @@ import pdb
40
 
41
  import matplotlib.pyplot as plt
42
 
 
 
 
 
 
 
 
 
 
43
 
44
  def open_tiff(fname):
45
 
@@ -137,7 +142,6 @@ def inference_on_file(target_image, model, custom_test_pipeline):
137
 
138
  # output_image = target_image.replace('.tif', '_pred.tif')
139
  time_taken=-1
140
- <<<<<<< HEAD
141
  st = time.time()
142
  print('Running inference...')
143
  result = inference_segmentor(model, target_image, custom_test_pipeline)
@@ -146,9 +150,9 @@ def inference_on_file(target_image, model, custom_test_pipeline):
146
  ##### get metadata mask
147
  mask = open_tiff(target_image)
148
  # rgb = mask[[2, 1, 0], :, :].transpose((1,2,0))
149
- rgb1 = mask[[2, 1, 0], :, :].transpose((1,2,0))
150
- rgb2 = mask[[8, 7, 6], :, :].transpose((1,2,0))
151
- rgb3 = mask[[14, 13, 12], :, :].transpose((1,2,0))
152
  meta = get_meta(target_image)
153
  mask = np.where(mask == meta['nodata'], 1, 0)
154
  mask = np.max(mask, axis=0)[None]
@@ -165,43 +169,11 @@ def inference_on_file(target_image, model, custom_test_pipeline):
165
  et = time.time()
166
  time_taken = np.round(et - st, 1)
167
  print(f'Inference completed in {str(time_taken)} seconds')
 
 
 
168
 
169
- return rgb1,rgb2,rgb3, result[0][0]
170
- =======
171
- try:
172
- st = time.time()
173
- print('Running inference...')
174
- result = inference_segmentor(model, target_image, custom_test_pipeline)
175
- print("Output has shape: " + str(result[0].shape))
176
-
177
- ##### get metadata mask
178
- mask = open_tiff(target_image)
179
- # rgb = mask[[2, 1, 0], :, :].transpose((1,2,0))
180
- rgb1 = mask[[2, 1, 0], :, :].transpose((1,2,0))
181
- rgb2 = mask[[8, 7, 6], :, :].transpose((1,2,0))
182
- rgb3 = mask[[14, 13, 12], :, :].transpose((1,2,0))
183
- meta = get_meta(target_image)
184
- mask = np.where(mask == meta['nodata'], 1, 0)
185
- mask = np.max(mask, axis=0)[None]
186
-
187
- result[0] = np.where(mask == 1, -1, result[0])
188
-
189
- ##### Save file to disk
190
- meta["count"] = 1
191
- meta["dtype"] = "int16"
192
- meta["compress"] = "lzw"
193
- meta["nodata"] = -1
194
- print('Saving output...')
195
- # write_tiff(result[0], output_image, meta)
196
- et = time.time()
197
- time_taken = np.round(et - st, 1)
198
- print(f'Inference completed in {str(time_taken)} seconds')
199
-
200
- except:
201
- print(f'Error on image {target_image} \nContinue to next input')
202
-
203
- return rgb, result[0][0]*255
204
- >>>>>>> 889a651 (add files)
205
 
206
  def process_test_pipeline(custom_test_pipeline, bands=None):
207
 
@@ -224,10 +196,6 @@ def process_test_pipeline(custom_test_pipeline, bands=None):
224
 
225
  return custom_test_pipeline
226
 
227
- <<<<<<< HEAD
228
-
229
- =======
230
- >>>>>>> 889a651 (add files)
231
  config = Config.fromfile(config_path)
232
  config.model.backbone.pretrained=None
233
  model = init_segmentor(config, ckpt, device='cpu')
@@ -260,16 +228,6 @@ with gr.Blocks() as demo:
260
 
261
  btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])
262
 
263
- <<<<<<< HEAD
264
- # with gr.Row():
265
- # gr.Examples(examples=["chip_102_345_merged.tif", "chip_104_104_merged.tif", "chip_109_421_merged.tif"],
266
- # inputs=inp,
267
- # outputs=[inp1, inp2, inp3, out],
268
- # preprocess=preprocess_example,
269
- # fn=func,
270
- # cache_examples=True,
271
- # )
272
- =======
273
  with gr.Row():
274
  gr.Examples(examples=["chip_102_345_merged.tif",
275
  "chip_104_104_merged.tif",
@@ -280,6 +238,5 @@ with gr.Blocks() as demo:
280
  fn=func,
281
  cache_examples=True,
282
  )
283
- >>>>>>> 889a651 (add files)
284
 
285
  demo.launch()
 
5
  filename="multi_temporal_crop_classification_Prithvi_100M.py",
6
  token=os.environ.get("token"))
7
  ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
 
8
  filename='multi_temporal_crop_classification_Prithvi_100M.pth',
 
 
 
9
  token=os.environ.get("token"))
10
  ##########
11
  import argparse
 
36
 
37
  import matplotlib.pyplot as plt
38
 
39
+ from skimage import exposure
40
+
41
+ def stretch_rgb(rgb):
42
+
43
+ ls_pct=0
44
+ pLow, pHigh = np.percentile(rgb1[~np.isnan(rgb1)], (ls_pct,100-ls_pct))
45
+ img_rescale = exposure.rescale_intensity(rgb1, in_range=(pLow,pHigh))
46
+
47
+ return img_rescale
48
 
49
  def open_tiff(fname):
50
 
 
142
 
143
  # output_image = target_image.replace('.tif', '_pred.tif')
144
  time_taken=-1
 
145
  st = time.time()
146
  print('Running inference...')
147
  result = inference_segmentor(model, target_image, custom_test_pipeline)
 
150
  ##### get metadata mask
151
  mask = open_tiff(target_image)
152
  # rgb = mask[[2, 1, 0], :, :].transpose((1,2,0))
153
+ rgb1 = stretch_rgb((mask[[2, 1, 0], :, :].transpose((1,2,0))/10000*255).astype(np.unint8))
154
+ rgb2 = stretch_rgb((mask[[8, 7, 6], :, :].transpose((1,2,0))/10000*255).astype(np.unint8))
155
+ rgb3 = stretch_rgb((mask[[14, 13, 12], :, :].transpose((1,2,0))/10000*255).astype(np.unint8))
156
  meta = get_meta(target_image)
157
  mask = np.where(mask == meta['nodata'], 1, 0)
158
  mask = np.max(mask, axis=0)[None]
 
169
  et = time.time()
170
  time_taken = np.round(et - st, 1)
171
  print(f'Inference completed in {str(time_taken)} seconds')
172
+
173
+ output = (result[0][0]*18).astype(np.uint8)
174
+ output = np.vstack([output[None], output[None], output[None]])
175
 
176
+ return rgb1,rgb2,rgb3,output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  def process_test_pipeline(custom_test_pipeline, bands=None):
179
 
 
196
 
197
  return custom_test_pipeline
198
 
 
 
 
 
199
  config = Config.fromfile(config_path)
200
  config.model.backbone.pretrained=None
201
  model = init_segmentor(config, ckpt, device='cpu')
 
228
 
229
  btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])
230
 
 
 
 
 
 
 
 
 
 
 
231
  with gr.Row():
232
  gr.Examples(examples=["chip_102_345_merged.tif",
233
  "chip_104_104_merged.tif",
 
238
  fn=func,
239
  cache_examples=True,
240
  )
 
241
 
242
  demo.launch()