Spaces:
Sleeping
Sleeping
| import base64 | |
| import streamlit as st | |
| import zipfile | |
| from utils import * | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.animation as animation | |
| import streamlit.components.v1 as components | |
| from matplotlib import colors | |
# Page configuration: use the wide layout (full browser width) for the app.
st.set_page_config(layout="wide")
def create_animation(images, pred_dates, fps=2):
    """Build a matplotlib animation that shows one image per date.

    Args:
        images: non-empty sequence of frames accepted by ``imshow`` (2-D
            arrays or H x W x 3 colour arrays); all frames share one axes.
        pred_dates: one date string per frame, shown in a caption box.
        fps: playback speed in frames per second (defaults to 2, the value
            that used to be hard-coded).

    Returns:
        The ``matplotlib.animation.FuncAnimation``. Its figure is left open
        on purpose: the animation is rendered after this function returns.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("create_animation() needs at least one image")
    print('Creating composition of images...')
    fig_an, ax_an = plt.subplots()
    ax_an.set_title("")
    im = ax_an.imshow(images[0], interpolation='none', aspect='auto', vmin=0, vmax=1)
    title = ax_an.text(0.5, 0.85, "", bbox={'facecolor': 'w', 'alpha': 0.5, 'pad': 5},
                       transform=ax_an.transAxes, ha="center")

    def animate_func(idx):
        title.set_text("date: " + pred_dates[idx])
        im.set_array(images[idx])
        # With blit=True only the returned artists are redrawn, so the date
        # caption has to be returned as well or it would not update on screen.
        return [im, title]

    anima = animation.FuncAnimation(fig_an, animate_func, frames=len(images),
                                    interval=1000 / fps, blit=True, repeat=False)
    print('Done!')
    return anima
def load_daily_preds_as_animations(pred_full_paths, pred_dates):
    """Animate the daily prediction rasters, one colour-mapped frame per file.

    Each raster is read with ``read``, squeezed, and every row of class ids
    is looked up in ``classes_color_map`` to obtain colours; the resulting
    frames are passed, together with ``pred_dates``, to ``create_animation``.
    """
    def _colorize(raster_path):
        labels, _profile = read(raster_path)
        return [classes_color_map[row] for row in np.squeeze(labels)]

    frames = [_colorize(raster_path) for raster_path in pred_full_paths]
    return create_animation(frames, pred_dates)
def load_src_images_as_animations(img_paths, pred_dates):
    """Animate the input Sentinel-2 time series as true-colour RGB frames.

    Args:
        img_paths: paths of the per-date GeoTIFFs, in display order.
        pred_dates: one date string per image, used as the frame caption.

    Returns:
        The animation built by ``create_animation``.
    """
    imgs = []
    for path in img_paths:
        img, _ = read(path)
        # Band layout as listed in the original notes (9 channels):
        # B2(blue), B3(green), B4(red), B5, B6, B7, B8, B11, B12 — so indices
        # [2, 1, 0] pick red, green, blue. Reference for composites:
        # https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/composites/
        rgb = np.moveaxis(img[[2, 1, 0], :, :], 0, -1)
        peak = np.amax(rgb)
        # Scale each frame into [0, 1] for display. A blank acquisition
        # (peak of 0) would otherwise divide by zero and yield NaN frames.
        if peak > 0:
            imgs.append(rgb / peak)
        else:
            imgs.append(np.zeros(rgb.shape, dtype=float))
    return create_animation(imgs, pred_dates)
# One-time model setup. The hasattr() guards skip this work when the
# attributes already exist on the `streamlit` module object. That object
# persists in the Python process, so later re-executions of this script
# reuse the models that were already built.
if not hasattr(st, 'paths'):
    st.paths = None
if not hasattr(st, 'daily_model'):
    # Checkpoint file names, joined with opt.resume_path below.
    best_model_daily_file_name = "best_model_daily.pth"
    best_model_annual_file_name = "best_model_annual.pth"
    # Zero tensor of shape (71, 9, 5, 48, 48) handed to the FPN constructor;
    # presumably used to size its layers — TODO confirm against FPN.
    first_input_batch = torch.zeros(71, 9, 5, 48, 48)
    # first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
    st.daily_model = FPN(opt, first_input_batch, opt.win_size)
    # The annual model is later fed the daily model's features.
    st.annual_model = SimpleNN(opt)
    # NOTE(review): each model is wrapped in DataParallel TWICE in both
    # branches, giving DataParallel(DataParallel(model)). This may be
    # deliberate so state-dict keys match checkpoints saved from a doubly
    # wrapped model — verify against the .pth files before removing a wrap.
    if torch.cuda.is_available():
        st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
        st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
    else:
        st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
        st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
    print('trying to resume previous saved models...')
    # `resume` (defined elsewhere) presumably loads each checkpoint's weights
    # into the given model. Its return value is stored in `state`, which is
    # not used afterwards.
    state = resume(
        os.path.join(opt.resume_path, best_model_daily_file_name),
        model=st.daily_model, optimizer=None)
    state = resume(
        os.path.join(opt.resume_path, best_model_annual_file_name),
        model=st.annual_model, optimizer=None)
    # Switch both models to evaluation mode for inference.
    st.daily_model = st.daily_model.eval()
    st.annual_model = st.annual_model.eval()
# Page header: app title, citation of the paper the model comes from, and a
# status notice for users.
st.title('In-season and dynamic crop mapping using 3D convolution neural networks and sentinel-2 time series')
st.markdown(""" Demo App for the model presented in the [paper](https://www.sciencedirect.com/science/article/pii/S0924271622003203):
```
@article{gallo2022in_season,
title = {In-season and dynamic crop mapping using 3D convolution neural networks and sentinel-2 time series},
journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
volume = {195},
pages = {335-352},
year = {2023},
issn = {0924-2716},
doi = {https://doi.org/10.1016/j.isprsjprs.2022.12.005},
url = {https://www.sciencedirect.com/science/article/pii/S0924271622003203},
author = {Ignazio Gallo and Luigi Ranghetti and Nicola Landro and Riccardo {La Grassa} and Mirco Boschetti},
}
```
**NOTE: The demo doesn't work properly, we are working to fix the bugs!**
""")
# Sample selection: an uploaded zip archive, or one of the bundled demo
# tiles. A demo button clicked in this run overrides the upload because it
# is evaluated afterwards.
file_uploaded = st.file_uploader(
    "Upload a zip file containing a sample",
    type=["zip"],
    accept_multiple_files=False,
)
sample_path = None
tileids = None
# Reset on every run of the script.
st.paths = None
if file_uploaded is not None:
    extract_dir = os.path.join("uploaded_samples", opt.years[0])
    with zipfile.ZipFile(file_uploaded, "r") as archive:
        archive.extractall(extract_dir)
    # Patch id = archive file name without its four-character ".zip" suffix.
    tileids = [file_uploaded.name[:-4]]
    sample_path = "uploaded_samples"

st.markdown('or use a demo sample')
demo_samples = (
    ('sample 1', '24'),
    ('sample 2', '712'),
    ('sample 3', '814'),
    ('sample 4', '1509'),
)
for column, (button_label, demo_tile) in zip(st.columns([1, 1, 1, 1]), demo_samples):
    with column:
        if st.button(button_label):
            sample_path = 'demo_data/lombardia'
            tileids = [demo_tile]
# Run both models on the selected sample and write the annual and per-date
# segmentation maps as GeoTIFFs under <opt.result_path>/seg_maps.
if sample_path is not None:
    validationdataset = SentinelDailyAnnualDatasetNoLabel(
        sample_path,
        opt.years,
        opt.classes_path,
        opt.sample_duration,
        opt.win_size,
        tileids=tileids)
    validationdataloader = torch.utils.data.DataLoader(
        validationdataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)
    st.markdown('Model prediction in progress ...')
    out_dir = os.path.join(opt.result_path, "seg_maps")
    os.makedirs(out_dir, exist_ok=True)
    # Stays None if the loader yields no batches (checked after the loop).
    full_pth_patch = None
    for x_dailies, _, dirs_path in validationdataloader:
        with torch.no_grad():
            # Remember the real (batch, time) sizes before merging them: the
            # batch can be smaller than opt.batch_size (last batch, or fewer
            # samples than the configured batch size).
            n_samples, n_steps = x_dailies.shape[:2]
            # Merge the first two dimensions for the per-date model.
            x_dailies = x_dailies.view(-1, *x_dailies.shape[2:])
            if torch.cuda.is_available():
                x_dailies = x_dailies.cuda()
            feat_daily, outs_daily = st.daily_model.forward(x_dailies)
            # Restore the (batch, time) layout.
            outs_daily = outs_daily.view(n_samples, n_steps, *outs_daily.shape[1:])
            feat_daily = feat_daily.view(n_samples, n_steps, *feat_daily.shape[1:])
            _, out_annual = st.annual_model.forward(feat_daily)
            pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
            pred_annual = pred_annual.cpu().numpy()
            # Remap model class ids to dataset labels.
            pred_annual_nn = ids_to_labels(
                validationdataloader, pred_annual).astype(np.uint8)
            for batch in range(feat_daily.shape[0]):
                # NOTE(review): the georeferencing profile and the date list
                # always come from sample 0, not from the current sample —
                # confirm this is intended when a run holds several samples.
                _, tmp_path = get_patch_id(validationdataset.samples, 0)
                dates = get_all_dates(
                    tmp_path, validationdataset.max_seq_length)
                last_tif_path = os.path.join(tmp_path, dates[-1] + ".tif")
                _, profile = read(last_tif_path)
                profile["name"] = dirs_path[batch]
                pth = dirs_path[batch].split(os.path.sep)[-3:]
                full_pth_patch = os.path.join(
                    out_dir, pth[1] + '-' + pth[0], pth[2])
                os.makedirs(full_pth_patch, exist_ok=True)
                full_pth_pred = os.path.join(
                    full_pth_patch, 'patch-pred-nn.tif')
                # Single-band uint8 label rasters; the same profile is reused
                # for the per-date maps below.
                profile.update({
                    'nodata': None,
                    'dtype': 'uint8',
                    'count': 1})
                with rasterio.open(full_pth_pred, 'w', **profile) as dst:
                    dst.write_band(1, pred_annual_nn[batch])
                for ch, date in enumerate(dates):
                    soft_seg = outs_daily[batch, ch, :, :, :]
                    # Turn class probabilities into a hard segmentation.
                    pred_daily = torch.argmax(soft_seg, dim=0).cpu()
                    daily_pred = ids_to_labels(
                        validationdataloader, pred_daily).astype(np.uint8)
                    full_pth_date = os.path.join(
                        full_pth_patch, date + '-daily-pred.tif')
                    with rasterio.open(full_pth_date, 'w', **profile) as dst:
                        dst.write_band(1, daily_pred)
    st.markdown('End prediction')
    if full_pth_patch is None:
        st.warning('No data found for the selected sample.')
    else:
        # Only the last processed patch is shown.
        folder_out = full_pth_patch
        st.paths = sorted(os.listdir(folder_out))
# Display the annual map, its legend, and animations of the daily
# predictions and of the input imagery.
if st.paths is not None:
    # Input imagery lives under the same root the prediction ran on: the
    # bundled demo folder or the directory the uploaded zip was extracted to.
    src_root = sample_path if sample_path is not None else "demo_data/lombardia"
    folder_src = os.path.join(src_root, opt.years[0], tileids[0])
    st.markdown("""
### Predictions
""")
    # Annual (long-term) prediction, coloured per class.
    file_path = os.path.join(folder_out, 'patch-pred-nn.tif')
    target, profile = read(file_path)
    target = np.squeeze(target)
    target = [classes_color_map[p] for p in target]
    fig, ax = plt.subplots()
    ax.imshow(target)
    markdown_legend = ''.join(
        f'<div style="color:gray;background-color: {colors.to_hex(c)};">{label}</div><br>'
        for c, label in zip(color_labels, labels_map)
    )
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown("**Long-term (annual) prediction**")
        st.pyplot(fig)
    # st.pyplot has already rendered the figure; close it so reruns do not
    # keep accumulating open matplotlib figures.
    plt.close(fig)
    with col2:
        st.markdown("**Legend**")
        st.markdown(markdown_legend, unsafe_allow_html=True)
    st.markdown("**Short-term (daily) predictions**")
    # Per-date outputs are named "<YYYYMMDD>-daily-pred.tif".
    daily_names = [name for name in st.paths if 'daily-pred' in name]
    img_full_paths = [os.path.join(folder_out, name) for name in daily_names]
    pred_dates = [name[:8] for name in daily_names]
    anim = load_daily_preds_as_animations(img_full_paths, pred_dates)
    components.html(anim.to_jshtml(), height=600)
    st.markdown("**Input time series**")
    src_names = sorted(f for f in os.listdir(folder_src) if f.endswith(".tif"))
    img_full_paths = [os.path.join(folder_src, f) for f in src_names]
    pred_dates = [f[:8] for f in src_names]
    anim_src = load_src_images_as_animations(img_full_paths, pred_dates)
    components.html(anim_src.to_jshtml(), height=600)
# Help section: where to get more data, a downloadable example input, and
# the expected layout of an uploaded sample.
st.markdown("""
## Lombardia Dataset
You can download other patches from the original dataset created and published on
[Kaggle](https://www.kaggle.com/datasets/ignazio/sentinel2-crop-mapping) and used in our paper.
## How to build an input file for the Demo
You can download the following zip example to better understand how to create a new sample to feed as input to the model. """)
with open("demo_data/1509.zip", "rb") as fp:
    st.download_button(
        label="Download ZIP example",
        data=fp,
        file_name="1509.zip",
        mime="application/octet-stream"
    )
st.markdown("""
A sample is a time series of sentinel-2 images,
i.e. all images acquired by the satellite during a year.
A zip file must contain
- a geoTiff image of size _9 x 48 x 48_ for each date of the time series;
- the name of each geoTiff must show the date like this example "20221225.tif" which represents the date 25 December 2022;
- each image must contain all sentinel-2 bands as reported in the [paper](https://www.sciencedirect.com/science/article/pii/S0924271622003203);
- all the images inside the zip file must be placed inside a directory (see ZIP example) where the name represents the name of the patch (for example "24").
""")