import os import shutil import tempfile import gradio as gr import plotly.graph_objects as go import pandas as pd from time import time from utils import ( create_file_structure, init_info_csv, add_to_info_csv, ) from satseg.dataset import create_datasets, create_inference_dataset from satseg.model import train_model, save_model, run_inference, load_model from satseg.seg_result import combine_seg_maps, get_combined_map_contours from satseg.geo_tools import ( shapefile_to_latlong, shapefile_to_grid_indices, points_to_shapefile, contours_to_shapefile, get_tif_n_channels, ) DATA_DIR = "data" MODEL_DIR = os.path.join(DATA_DIR, "models") TIF_DIR = os.path.join(DATA_DIR, "tifs") MASK_DIR = os.path.join(DATA_DIR, "masks") INFO_DIR = os.path.join(DATA_DIR, "info") MODEL_INFO_PATH = os.path.join(INFO_DIR, "model_data.csv") DATASET_TIF_INFO_PATH = os.path.join(INFO_DIR, "dataset_tif_data.csv") DATASET_MASK_INFO_PATH = os.path.join(INFO_DIR, "dataset_mask_data.csv") create_file_structure( [DATA_DIR, TIF_DIR, MASK_DIR, INFO_DIR], [MODEL_INFO_PATH, DATASET_TIF_INFO_PATH, DATASET_MASK_INFO_PATH], ) init_info_csv( MODEL_INFO_PATH, [ "Name", "Architecture", "# of channels", "Train TIF", "Train Mask", "Expression", "Path", ], ) init_info_csv(DATASET_TIF_INFO_PATH, ["Name", "# of channels", "Path"]) init_info_csv(DATASET_MASK_INFO_PATH, ["Name", "Class", "Path"]) def gr_train_model( tif_names, mask_names, model_name, expression, progress=gr.Progress() ): tif_paths = list(map(lambda x: os.path.join(TIF_DIR, x), tif_names)) mask_paths = list(map(lambda x: os.path.join(MASK_DIR, x), mask_names)) expression = expression.strip().split() # if arch.lower() == "best": # arch = "dcama" if len(train_set) > 8 and len(train_set) < 20 else "unet" # ( c6 - c0 ) / ( c6 + c0 ) = progress(0, desc="Creating Dataset...") with tempfile.TemporaryDirectory() as tempdir: train_set, val_set = create_datasets( tif_paths, mask_paths, tempdir, expression=expression ) progress(0.05, desc="Training Model...") model, _ = train_model(train_set, val_set, "unet") progress(0.95, desc="Model Trained! Saving...") model_name = "_".join(model_name.split()) + ".pt" model_path = os.path.join(MODEL_DIR, model_name) save_model(model, model_path) add_to_info_csv( MODEL_INFO_PATH, [ model_name, "UNet", val_set.n_channels, ";".join(tif_names), ";".join(mask_names), " ".join(expression), model_path, ], ) progress(1.0, desc="Done!") model_df = pd.read_csv(MODEL_INFO_PATH) return "Done!", model_df, gr.Dropdown.update(choices=model_df["Name"].to_list()) def gr_run_inference(tif_names, model_name, progress=gr.Progress()): t = time() tif_paths = list(map(lambda x: os.path.join(TIF_DIR, x), tif_names)) model_df = pd.read_csv(MODEL_INFO_PATH, index_col="Name") model_path = model_df["Path"][model_name] with tempfile.TemporaryDirectory() as tempdir: progress(0, desc="Creating Dataset...") dataset = create_inference_dataset( tif_paths, tempdir, 256, expression=model_df["Expression"][model_name].split(), ) progress(0.1, desc="Loading Model...") model = load_model(model_path) result_dir = os.path.join(tempdir, "infer") comb_result_dir = os.path.join(tempdir, "comb") os.makedirs(result_dir) os.makedirs(comb_result_dir) progress(0.2, desc="Running Inference...") run_inference(dataset, model, result_dir) progress(0.8, desc="Preparing output...") combine_seg_maps(result_dir, comb_result_dir) results = get_combined_map_contours(comb_result_dir) file_paths = [] out_dir = os.path.join(MASK_DIR, "output") if os.path.exists(out_dir): shutil.rmtree(out_dir) os.makedirs(out_dir) for tif_name, (contours, hierarchy) in results.items(): tif_path = os.path.join(TIF_DIR, f"{tif_name}.tif") mask_path = os.path.join(out_dir, f"{tif_name}_mask.shp") zip_path = contours_to_shapefile(contours, hierarchy, tif_path, mask_path) file_paths.append(zip_path) print(time() - t, "seconds") return file_paths def gr_save_mask_file(file_objs, filenames, obj_class): print("Saving file(s)...") idx = 0 for filename in filenames.split(";"): if filename.strip() == "": continue filepath = os.path.join(MASK_DIR, filename.strip()) obj = file_objs[idx] idx += 1 shutil.move(obj.name, filepath) if filename.endswith(".shp"): add_to_info_csv(DATASET_MASK_INFO_PATH, [filename, obj_class, filepath]) print("Done!") dataset_df = pd.read_csv(DATASET_MASK_INFO_PATH) choices = dataset_mask_df["Name"].to_list() update = gr.Dropdown.update(choices=choices) return dataset_df, update, update def gr_save_tif_file(file_objs, filenames): print("Saving file(s)...") idx = 0 for filename in filenames.split(";"): if filename.strip() == "": continue filepath = os.path.join(TIF_DIR, filename.strip()) obj = file_objs[idx] idx += 1 shutil.copy2(obj.name, filepath) n = get_tif_n_channels(filepath) add_to_info_csv(DATASET_TIF_INFO_PATH, [filename, n, filepath]) print("Done!") dataset_df = pd.read_csv(DATASET_TIF_INFO_PATH) choices = dataset_mask_df["Name"].to_list() update = gr.Dropdown.update(choices=choices) return dataset_df, update, update def gr_generate_map(mask_name: str, token: str = "", show_grid=True, show_mask=False): mask_path = os.path.join(MASK_DIR, mask_name) # token = "pk.eyJ1IjoiZGlsaXRoIiwiYSI6ImNsaDQ3NXF3ZDAxdDMzZXMxeWJic2h1cDQifQ.DDczQCDfTgQEUt6pGvjUAg" center = (7.753769, 80.691730) scattermaps = [] if show_grid: indices = shapefile_to_grid_indices(mask_path) points_to_shapefile(indices, mask_path[: -len(".shp")] + "-grid.shp") scattermaps.append( go.Scattermapbox( lat=indices[:, 1], lon=indices[:, 0], mode="markers", marker=go.scattermapbox.Marker(size=6), ) ) if show_mask: contours = shapefile_to_latlong(mask_path) for contour in contours[38:39]: lons = contour[:, 0] lats = contour[:, 1] scattermaps.append( go.Scattermapbox( fill="toself", lat=lats, lon=lons, mode="markers", marker=go.scattermapbox.Marker(size=6), ) ) fig = go.Figure(scattermaps) if token: fig.update_layout( mapbox=dict( style="satellite-streets", accesstoken=token, center=go.layout.mapbox.Center(lat=center[0], lon=center[1]), pitch=0, zoom=7, ), mapbox_layers=[ { # "below": "traces", "sourcetype": "raster", "sourceattribution": "United States Geological Survey", "source": [ "https://basemap.nationalmap.gov/arcgis/rest/services/USGSImageryOnly/MapServer/tile/{z}/{y}/{x}" ], } ], ) else: fig.update_layout( mapbox_style="open-street-map", hovermode="closest", mapbox=dict( bearing=0, center=go.layout.mapbox.Center(lat=center[0], lon=center[1]), pitch=0, zoom=7, ), ) return fig with gr.Blocks() as demo: gr.Markdown( """# SatSeg Train models and run inference for segmentation of multispectral satellite images.""" ) model_df = pd.read_csv(MODEL_INFO_PATH) dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH) dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH) with gr.Tab("Train"): train_tif_names = gr.Dropdown( label="TIF Files", choices=dataset_tif_df["Name"].to_list(), multiselect=True, ) train_mask_names = gr.Dropdown( label="Mask files", choices=dataset_mask_df["Name"].to_list(), multiselect=True, ) train_rs_index = gr.Textbox( label="Remote Sensing Index", placeholder="( c0 + c1 ) / ( c0 - c1 ) =" ) # train_arch = gr.Dropdown( # label="Model Architecture", choices=["Best", "UNet", "DCAMA"], value="Best" # ) train_model_name = gr.Textbox( label="Model Name", placeholder="Give the model a name" ) train_button = gr.Button("Train") train_completion = gr.Text(label="Training Status", value="Not Started") with gr.Tab("Infer"): infer_tif_names = gr.Dropdown( label="TIF Files", choices=dataset_tif_df["Name"].to_list(), multiselect=True, ) infer_model_name = gr.Dropdown( label="Model Name", choices=model_df["Name"].to_list(), ) infer_button = gr.Button("Infer") infer_mask = gr.Files(label="Output Shapefile", interactive=False) # with gr.Tab("Sampling"): # grid_mask_name = gr.Dropdown( # label="Mask", # choices=dataset_mask_df["Name"].to_list(), # ) # grid_token = gr.Textbox( # value="", label="Mapbox Token (https://account.mapbox.com/)" # ) # grid_side_len = gr.Textbox(value="100", label="Sampling Gap (m)") # grid_show_grid = gr.Checkbox(True, label="Show Grid") # grid_show_mask = gr.Checkbox(False, label="Show Mask") # grid_button = gr.Button("Generate Grid") # grid_map = gr.Plot(label="Plot") with gr.Tab("Datasets"): dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH) dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH) datasets_upload_tif = gr.File(label="Images (.tif)", file_count="multiple") datasets_upload_tif_name = gr.Textbox( label="TIF name", placeholder="tif_file_1.tif;tif_file_2.tif" ) datasets_save_uploaded_tif = gr.Button("Save") datasets_upload_mask = gr.File( label="Masks (Please upload all extensions (.shp, .shx, etc.))", file_count="multiple", ) datasets_upload_mask_name = gr.Textbox( label="Mask name", placeholder="mask_1.shp;mask_1.shx" ) datasets_mask_class_name = gr.Textbox( label="Class (The name of the object you want to segment)" ) datasets_save_uploaded_mask = gr.Button("Save") datasets_tif_table = gr.Dataframe(dataset_tif_df, label="TIFs") datasets_mask_table = gr.Dataframe(dataset_mask_df, label="Masks") with gr.Tab("Models"): models_table = gr.Dataframe(model_df) train_button.click( gr_train_model, inputs=[ train_tif_names, train_mask_names, # train_arch, train_model_name, train_rs_index, ], outputs=[train_completion, models_table, infer_model_name], ) infer_button.click( gr_run_inference, inputs=[infer_tif_names, infer_model_name], outputs=[infer_mask], ) datasets_upload_tif.upload( lambda y: ";".join(list(map(lambda x: os.path.basename(x.orig_name), y))), inputs=datasets_upload_tif, outputs=datasets_upload_tif_name, ) datasets_upload_mask.upload( lambda y: ";".join(list(map(lambda x: os.path.basename(x.orig_name), y))), inputs=datasets_upload_mask, outputs=datasets_upload_mask_name, ) # grid_button.click( # gr_generate_map, # inputs=[grid_mask_name, grid_token, grid_show_grid, grid_show_mask], # outputs=grid_map, # ) datasets_save_uploaded_tif.click( gr_save_tif_file, inputs=[datasets_upload_tif, datasets_upload_tif_name], outputs=[datasets_tif_table, train_tif_names, infer_tif_names], ) datasets_save_uploaded_mask.click( gr_save_mask_file, inputs=[ datasets_upload_mask, datasets_upload_mask_name, datasets_mask_class_name, ], outputs=[datasets_mask_table, train_mask_names], ) demo.queue(concurrency_count=10).launch(debug=True)