Paolo-Fraccaro commited on
Commit
c89f816
0 Parent(s):
Files changed (4) hide show
  1. Dockerfile +63 -0
  2. README.md +11 -0
  3. app.py +198 -0
  4. requirements.txt +4 -0
Dockerfile ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM ubuntu:18.04
2
+
3
+
4
+ RUN apt-get update && apt-get install --no-install-recommends -y \
5
+ build-essential \
6
+ python3.8 \
7
+ python3-pip \
8
+ python3-setuptools \
9
+ git \
10
+ wget \
11
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
12
+
13
+ WORKDIR /code
14
+
15
+ COPY ./requirements.txt /code/requirements.txt
16
+
17
+ # add conda
18
+ RUN RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh /code/
19
+ RUN chmod 777 /code/Miniconda3-latest-Linux-x86_64.sh
20
+
21
+
22
+ # Set up a new user named "user" with user ID 1000
23
+ RUN useradd -m -u 1000 user
24
+ # Switch to the "user" user
25
+ USER user
26
+ # Set home to the user's home directory
27
+ ENV HOME=/home/user \
28
+ PATH=/home/user/.local/bin:$PATH \
29
+ PYTHONPATH=$HOME/app \
30
+ PYTHONUNBUFFERED=1 \
31
+ GRADIO_ALLOW_FLAGGING=never \
32
+ GRADIO_NUM_PORTS=1 \
33
+ GRADIO_SERVER_NAME=0.0.0.0 \
34
+ GRADIO_THEME=huggingface \
35
+ SYSTEM=spaces
36
+
37
+ RUN /code/Miniconda3-latest-Linux-x86_64.sh -b -p /miniconda
38
+ ENV PATH="/miniconda/bin:${PATH}"
39
+
40
+
41
+ # RUN /miniconda/bin/conda init bash
42
+
43
+ # RUN conda install python=3.9
44
+
45
+
46
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
47
+
48
+ RUN git clone git+https://$(cat /run/secrets/git_token)@github.com/NASA-IMPACT/hls-foundation-os.git
49
+
50
+ RUN pip3 install fine-tuning-examples/
51
+
52
+ # RUN --mount=type=secret,id=git_token,mode=0444,required=true \
53
+ # pip3 install git+https://$(cat /run/secrets/git_token)@github.com/NASA-IMPACT/hls-foundation-os.git@mmseg-only
54
+
55
+ RUN mim install mmcv-full==1.5.0
56
+
57
+ # Set the working directory to the user's home directory
58
+ WORKDIR $HOME/app
59
+
60
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
61
+ COPY --chown=user . $HOME/app
62
+
63
+ CMD ["python3", "app.py"]
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Prithvi 100M Burn Scars Demo
3
+ emoji: 🌖
4
+ colorFrom: purple
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ license: apache-2.0
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ######### pull files
2
+ from huggingface_hub import hf_hub_download
3
+ config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/burn-scar-Prithvi-100M", filename="Prithvi_100M_config.yaml", token=os.environ.get("token"))
4
+ ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/burn-scar-Prithvi-100M", filename='Prithvi_100M.pt', token=os.environ.get("token"))
5
+ ##########
6
+
7
+
8
+ import argparse
9
+ from mmcv import Config
10
+
11
+ from mmseg.models import build_segmentor
12
+
13
+ from mmseg.datasets.pipelines import Compose, LoadImageFromFile
14
+
15
+ import rasterio
16
+ import torch
17
+
18
+ from mmseg.apis import init_segmentor
19
+
20
+ from mmcv.parallel import collate, scatter
21
+
22
+ import numpy as np
23
+ import glob
24
+ import os
25
+
26
+ import time
27
+
28
+ import numpy as np
29
+ import gradio as gr
30
+ from functools import partial
31
+
32
+ import pdb
33
+
34
+ import matplotlib.pyplot as plt
35
+
36
+
37
+ def open_tiff(fname):
38
+
39
+ with rasterio.open(fname, "r") as src:
40
+
41
+ data = src.read()
42
+
43
+ return data
44
+
45
+ def write_tiff(img_wrt, filename, metadata):
46
+
47
+ """
48
+ It writes a raster image to file.
49
+
50
+ :param img_wrt: numpy array containing the data (can be 2D for single band or 3D for multiple bands)
51
+ :param filename: file path to the output file
52
+ :param metadata: metadata to use to write the raster to disk
53
+ :return:
54
+ """
55
+
56
+ with rasterio.open(filename, "w", **metadata) as dest:
57
+
58
+ if len(img_wrt.shape) == 2:
59
+
60
+ img_wrt = img_wrt[None]
61
+
62
+ for i in range(img_wrt.shape[0]):
63
+ dest.write(img_wrt[i, :, :], i + 1)
64
+
65
+ return filename
66
+
67
+
68
+ def get_meta(fname):
69
+
70
+ with rasterio.open(fname, "r") as src:
71
+
72
+ meta = src.meta
73
+
74
+ return meta
75
+
76
+
77
+
78
+ def inference_segmentor(model, imgs, custom_test_pipeline=None):
79
+ """Inference image(s) with the segmentor.
80
+
81
+ Args:
82
+ model (nn.Module): The loaded segmentor.
83
+ imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
84
+ images.
85
+
86
+ Returns:
87
+ (list[Tensor]): The segmentation result.
88
+ """
89
+ cfg = model.cfg
90
+ device = next(model.parameters()).device # model device
91
+ # build the data pipeline
92
+ test_pipeline = [LoadImageFromFile()] + cfg.data.test.pipeline[1:] if custom_test_pipeline == None else custom_test_pipeline
93
+ test_pipeline = Compose(test_pipeline)
94
+ # prepare data
95
+ data = []
96
+ imgs = imgs if isinstance(imgs, list) else [imgs]
97
+ for img in imgs:
98
+ img_data = {'img_info': {'filename': img}}
99
+ img_data = test_pipeline(img_data)
100
+ data.append(img_data)
101
+ # print(data.shape)
102
+
103
+ data = collate(data, samples_per_gpu=len(imgs))
104
+ if next(model.parameters()).is_cuda:
105
+ # data = collate(data, samples_per_gpu=len(imgs))
106
+ # scatter to specified GPU
107
+ data = scatter(data, [device])[0]
108
+ else:
109
+ # img_metas = scatter(data['img_metas'],'cpu')
110
+ # data['img_metas'] = [i.data[0] for i in data['img_metas']]
111
+
112
+ img_metas = data['img_metas'].data[0]
113
+ img = data['img']
114
+ data = {'img': img, 'img_metas':img_metas}
115
+
116
+ with torch.no_grad():
117
+ result = model(return_loss=False, rescale=True, **data)
118
+ return result
119
+
120
+
121
+ def inference_on_file(target_image, model, custom_test_pipeline):
122
+
123
+ target_image = target_image.name
124
+ # print(type(target_image))
125
+
126
+ # output_image = target_image.replace('.tif', '_pred.tif')
127
+ time_taken=-1
128
+ try:
129
+ st = time.time()
130
+ print('Running inference...')
131
+ result = inference_segmentor(model, target_image, custom_test_pipeline)
132
+ print("Output has shape: " + str(result[0].shape))
133
+
134
+ ##### get metadata mask
135
+ mask = open_tiff(target_image)
136
+ # rgb = mask[[2, 1, 0], :, :].transpose((1,2,0))
137
+ rgb = mask[[5, 3, 2], :, :].transpose((1,2,0))
138
+ meta = get_meta(target_image)
139
+ mask = np.where(mask == meta['nodata'], 1, 0)
140
+ mask = np.max(mask, axis=0)[None]
141
+
142
+ result[0] = np.where(mask == 1, -1, result[0])
143
+
144
+ ##### Save file to disk
145
+ meta["count"] = 1
146
+ meta["dtype"] = "int16"
147
+ meta["compress"] = "lzw"
148
+ meta["nodata"] = -1
149
+ print('Saving output...')
150
+ # write_tiff(result[0], output_image, meta)
151
+ et = time.time()
152
+ time_taken = np.round(et - st, 1)
153
+ print(f'Inference completed in {str(time_taken)} seconds')
154
+
155
+ except:
156
+ print(f'Error on image {target_image} \nContinue to next input')
157
+
158
+ return rgb, result[0][0]*255
159
+
160
+ def process_test_pipeline(custom_test_pipeline, bands=None):
161
+
162
+ # change extracted bands if necessary
163
+ if bands is not None:
164
+
165
+ extract_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'] == 'BandsExtract' ]
166
+
167
+ if len(extract_index) > 0:
168
+
169
+ custom_test_pipeline[extract_index[0]]['bands'] = eval(bands)
170
+
171
+ collect_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'].find('Collect') > -1]
172
+
173
+ # adapt collected keys if necessary
174
+ if len(collect_index) > 0:
175
+
176
+ keys = ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']
177
+ custom_test_pipeline[collect_index[0]]['meta_keys'] = keys
178
+
179
+ return custom_test_pipeline
180
+
181
+ model = init_segmentor(config_path, ckpt, device='cpu')
182
+ custom_test_pipeline=process_test_pipeline(model.cfg.data.test.pipeline, None)
183
+
184
+ func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
185
+
186
+ with gr.Blocks() as demo:
187
+
188
+ with gr.Row():
189
+ with gr.Column():
190
+ inp = gr.File()
191
+ btn = gr.Button("Submit")
192
+ with gr.Row():
193
+ out1=gr.Image(image_mode='RGB')
194
+ out2 = gr.Image(image_mode='L')
195
+
196
+ btn.click(fn=func, inputs=inp, outputs=[out1, out2])
197
+
198
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==1.7.1
2
+ torchvision==0.8.2
3
+ openmim
4
+ gradio