|
import os |
|
import shutil |
|
import tempfile |
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
|
|
import pandas as pd |
|
from time import time |
|
from utils import ( |
|
create_file_structure, |
|
init_info_csv, |
|
add_to_info_csv, |
|
) |
|
|
|
from satseg.dataset import create_datasets, create_inference_dataset |
|
from satseg.model import train_model, save_model, run_inference, load_model |
|
from satseg.seg_result import combine_seg_maps, get_combined_map_contours |
|
from satseg.geo_tools import ( |
|
shapefile_to_latlong, |
|
shapefile_to_grid_indices, |
|
points_to_shapefile, |
|
contours_to_shapefile, |
|
get_tif_n_channels, |
|
) |
|
|
|
# On-disk layout used by the app (paths are relative to the working directory).
DATA_DIR = "data"
MODEL_DIR = os.path.join(DATA_DIR, "models")
TIF_DIR = os.path.join(DATA_DIR, "tifs")
MASK_DIR = os.path.join(DATA_DIR, "masks")
INFO_DIR = os.path.join(DATA_DIR, "info")

# CSV registries listing the saved models and the uploaded dataset files.
MODEL_INFO_PATH = os.path.join(INFO_DIR, "model_data.csv")
DATASET_TIF_INFO_PATH = os.path.join(INFO_DIR, "dataset_tif_data.csv")
DATASET_MASK_INFO_PATH = os.path.join(INFO_DIR, "dataset_mask_data.csv")

# MODEL_DIR is listed so that trained models saved into it (see gr_train_model)
# have an existing destination directory. Parents come before children in case
# create_file_structure creates one level at a time (its implementation lives
# in utils and is not visible here).
create_file_structure(
    [DATA_DIR, MODEL_DIR, TIF_DIR, MASK_DIR, INFO_DIR],
    [MODEL_INFO_PATH, DATASET_TIF_INFO_PATH, DATASET_MASK_INFO_PATH],
)
# Column headers for each registry; presumably init_info_csv only writes them
# when the CSV is new/empty — confirm against utils.init_info_csv.
init_info_csv(
    MODEL_INFO_PATH,
    [
        "Name",
        "Architecture",
        "# of channels",
        "Train TIF",
        "Train Mask",
        "Expression",
        "Path",
    ],
)
init_info_csv(DATASET_TIF_INFO_PATH, ["Name", "# of channels", "Path"])
init_info_csv(DATASET_MASK_INFO_PATH, ["Name", "Class", "Path"])
|
|
|
|
|
def gr_train_model(
    tif_names, mask_names, model_name, expression, progress=gr.Progress()
):
    """Train a UNet on the selected TIF/mask files, save it and register it.

    Args:
        tif_names: TIF file names (relative to TIF_DIR) chosen in the UI.
        mask_names: Mask file names (relative to MASK_DIR) chosen in the UI.
        model_name: Free-text model name. Whitespace runs are replaced by
            ``_`` and ``.pt`` is appended to form the saved file name, which is
            also the name recorded in the model registry.
        expression: Optional remote-sensing index expression. It is split on
            whitespace and passed as a token list to ``create_datasets``.
        progress: Gradio progress tracker injected by the event handler.

    Returns:
        A status string, the refreshed model registry DataFrame, and a
        dropdown update listing every registered model name.

    Raises:
        gr.Error: If no TIF or mask file is selected, or the model name is
            blank. These are checked before any dataset work or training
            starts.
    """
    # Validate cheap inputs first instead of failing after a long training run.
    if not tif_names or not mask_names:
        raise gr.Error("Select at least one TIF file and one mask file.")
    file_stem = "_".join((model_name or "").split())
    if not file_stem:
        raise gr.Error("Give the model a name.")
    model_file = file_stem + ".pt"
    model_path = os.path.join(MODEL_DIR, model_file)

    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]
    mask_paths = [os.path.join(MASK_DIR, name) for name in mask_names]
    expression_tokens = (expression or "").split()

    progress(0, desc="Creating Dataset...")
    # The datasets are materialised inside the temp dir, so training has to
    # happen before the directory is cleaned up.
    with tempfile.TemporaryDirectory() as tempdir:
        train_set, val_set = create_datasets(
            tif_paths, mask_paths, tempdir, expression=expression_tokens
        )
        progress(0.05, desc="Training Model...")
        model, _ = train_model(train_set, val_set, "unet")

        progress(0.95, desc="Model Trained! Saving...")
        # Make sure the destination directory exists before writing the model.
        os.makedirs(MODEL_DIR, exist_ok=True)
        save_model(model, model_path)
        add_to_info_csv(
            MODEL_INFO_PATH,
            [
                model_file,
                "UNet",
                val_set.n_channels,
                ";".join(tif_names),
                ";".join(mask_names),
                " ".join(expression_tokens),
                model_path,
            ],
        )

    progress(1.0, desc="Done!")
    model_df = pd.read_csv(MODEL_INFO_PATH)
    return "Done!", model_df, gr.Dropdown.update(choices=model_df["Name"].to_list())
|
|
|
|
|
def gr_run_inference(tif_names, model_name, progress=gr.Progress()):
    """Segment the selected TIFs with a registered model and export shapefiles.

    Args:
        tif_names: TIF file names (relative to TIF_DIR) chosen in the UI.
        model_name: Registry name of the model (the ``Name`` column of the
            model CSV).
        progress: Gradio progress tracker injected by the event handler.

    Returns:
        A list of paths returned by ``contours_to_shapefile``, one per entry of
        the combined inference result. They are written under
        ``MASK_DIR/output``, which is wiped on every call.

    Raises:
        gr.Error: If no TIF is selected or the model is not in the registry.
    """
    start = time()
    if not tif_names:
        raise gr.Error("Select at least one TIF file.")

    model_df = pd.read_csv(MODEL_INFO_PATH)
    matches = model_df[model_df["Name"] == model_name]
    if matches.empty:
        raise gr.Error(f"Unknown model: {model_name!r}")
    # Retraining under an existing name appends another registry row (and
    # overwrites the same .pt path), so the last row is the current one.
    model_row = matches.iloc[-1]
    model_path = model_row["Path"]
    raw_expression = model_row["Expression"]
    # An empty expression is stored as "" at training time, which read_csv
    # turns into NaN; treat that as "no expression" rather than calling
    # .split() on a float.
    expression = raw_expression.split() if isinstance(raw_expression, str) else []

    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]

    with tempfile.TemporaryDirectory() as tempdir:
        progress(0, desc="Creating Dataset...")
        dataset = create_inference_dataset(
            tif_paths,
            tempdir,
            256,
            expression=expression,
        )
        progress(0.1, desc="Loading Model...")
        model = load_model(model_path)

        result_dir = os.path.join(tempdir, "infer")
        comb_result_dir = os.path.join(tempdir, "comb")
        os.makedirs(result_dir)
        os.makedirs(comb_result_dir)

        progress(0.2, desc="Running Inference...")
        run_inference(dataset, model, result_dir)

        progress(0.8, desc="Preparing output...")
        combine_seg_maps(result_dir, comb_result_dir)
        results = get_combined_map_contours(comb_result_dir)

        # NOTE(review): this shared output dir is wiped on every call, so
        # concurrent inference requests can clobber each other's files.
        out_dir = os.path.join(MASK_DIR, "output")
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)

        file_paths = []
        for tif_name, (contours, hierarchy) in results.items():
            # The result keys are used as TIF stems (".tif" is re-appended) to
            # locate the source raster passed to contours_to_shapefile.
            tif_path = os.path.join(TIF_DIR, f"{tif_name}.tif")
            mask_path = os.path.join(out_dir, f"{tif_name}_mask.shp")
            file_paths.append(
                contours_to_shapefile(contours, hierarchy, tif_path, mask_path)
            )

    print(time() - start, "seconds")
    return file_paths
|
|
|
|
|
def gr_save_mask_file(file_objs, filenames, obj_class):
    """Move uploaded mask files into MASK_DIR and register the shapefiles.

    Args:
        file_objs: Uploaded file objects from the ``gr.File`` component, in the
            same order as the names in ``filenames``. Each object's ``.name``
            is the path of the uploaded temporary file.
        filenames: ``;``-separated target names. Surrounding whitespace is
            stripped and blank entries are skipped.
        obj_class: Class label recorded for every saved ``.shp`` file.

    Returns:
        The refreshed mask registry DataFrame and two identical dropdown updates
        carrying the current mask names.

    Raises:
        gr.Error: If more names are given than files were uploaded. This is
            checked before anything is moved.
    """
    print("Saving file(s)...")
    names = [name.strip() for name in (filenames or "").split(";") if name.strip()]
    uploads = file_objs or []
    # Fail up front instead of raising IndexError halfway through the batch.
    if len(names) > len(uploads):
        raise gr.Error(
            f"{len(names)} file name(s) given but only {len(uploads)} file(s) uploaded."
        )

    for name, obj in zip(names, uploads):
        filepath = os.path.join(MASK_DIR, name)
        shutil.move(obj.name, filepath)
        # Only the .shp is registered; companion files (.shx, ...) are just
        # stored next to it.
        if name.lower().endswith(".shp"):
            add_to_info_csv(DATASET_MASK_INFO_PATH, [name, obj_class, filepath])
    print("Done!")

    dataset_df = pd.read_csv(DATASET_MASK_INFO_PATH)
    # Build the choices from the freshly read table, not from the module-level
    # snapshot taken when the UI was built (which never sees new uploads).
    update = gr.Dropdown.update(choices=dataset_df["Name"].to_list())
    return dataset_df, update, update
|
|
|
|
|
def gr_save_tif_file(file_objs, filenames):
    """Copy uploaded TIFs into TIF_DIR and register them with their band count.

    Args:
        file_objs: Uploaded file objects from the ``gr.File`` component, in the
            same order as the names in ``filenames``. Each object's ``.name``
            is the path of the uploaded temporary file.
        filenames: ``;``-separated target names. Surrounding whitespace is
            stripped and blank entries are skipped.

    Returns:
        The refreshed TIF registry DataFrame and two identical dropdown updates
        carrying the current TIF names (one each for the Train and Infer tabs).

    Raises:
        gr.Error: If more names are given than files were uploaded. This is
            checked before anything is copied.
    """
    print("Saving file(s)...")
    names = [name.strip() for name in (filenames or "").split(";") if name.strip()]
    uploads = file_objs or []
    # Fail up front instead of raising IndexError halfway through the batch.
    if len(names) > len(uploads):
        raise gr.Error(
            f"{len(names)} file name(s) given but only {len(uploads)} file(s) uploaded."
        )

    for name, obj in zip(names, uploads):
        filepath = os.path.join(TIF_DIR, name)
        shutil.copy2(obj.name, filepath)
        n_channels = get_tif_n_channels(filepath)
        add_to_info_csv(DATASET_TIF_INFO_PATH, [name, n_channels, filepath])
    print("Done!")

    dataset_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    # The TIF dropdowns must list TIF names from the fresh table. The previous
    # code used the stale module-level *mask* DataFrame here.
    update = gr.Dropdown.update(choices=dataset_df["Name"].to_list())
    return dataset_df, update, update
|
|
|
|
|
def gr_generate_map(
    mask_name: str,
    token: str = "",
    show_grid=True,
    show_mask=False,
    center=(7.753769, 80.691730),
    zoom=7,
):
    """Build a Plotly map figure for a stored mask shapefile.

    Args:
        mask_name: File name of a ``.shp`` mask inside MASK_DIR.
        token: Mapbox access token. When non-empty, the "satellite-streets"
            style plus a USGS imagery raster layer is used. Otherwise the
            "open-street-map" style is used.
        show_grid: Plot the points from ``shapefile_to_grid_indices``. As a
            side effect, this also writes them to ``<mask stem>-grid.shp``
            next to the mask.
        show_mask: Plot every contour from ``shapefile_to_latlong`` as a
            filled trace.
        center: ``(lat, lon)`` the map is centred on.
        zoom: Initial map zoom level.

    Returns:
        A ``go.Figure`` with one trace for the grid (if requested) and one per
        mask contour (if requested).
    """
    mask_path = os.path.join(MASK_DIR, mask_name)
    mask_stem, _ = os.path.splitext(mask_path)

    scattermaps = []
    if show_grid:
        indices = shapefile_to_grid_indices(mask_path)
        points_to_shapefile(indices, mask_stem + "-grid.shp")
        # Column 0 is used as longitude and column 1 as latitude.
        scattermaps.append(
            go.Scattermapbox(
                lat=indices[:, 1],
                lon=indices[:, 0],
                mode="markers",
                marker=go.scattermapbox.Marker(size=6),
            )
        )

    if show_mask:
        # Draw every contour (previously a leftover debug slice limited this
        # to a single contour).
        for contour in shapefile_to_latlong(mask_path):
            scattermaps.append(
                go.Scattermapbox(
                    fill="toself",
                    lat=contour[:, 1],
                    lon=contour[:, 0],
                    mode="markers",
                    marker=go.scattermapbox.Marker(size=6),
                )
            )

    fig = go.Figure(scattermaps)
    map_center = go.layout.mapbox.Center(lat=center[0], lon=center[1])

    if token:
        fig.update_layout(
            mapbox=dict(
                style="satellite-streets",
                accesstoken=token,
                center=map_center,
                pitch=0,
                zoom=zoom,
            ),
            mapbox_layers=[
                {
                    "sourcetype": "raster",
                    "sourceattribution": "United States Geological Survey",
                    "source": [
                        "https://basemap.nationalmap.gov/arcgis/rest/services/USGSImageryOnly/MapServer/tile/{z}/{y}/{x}"
                    ],
                }
            ],
        )
    else:
        fig.update_layout(
            mapbox_style="open-street-map",
            hovermode="closest",
            mapbox=dict(
                bearing=0,
                center=map_center,
                pitch=0,
                zoom=zoom,
            ),
        )

    return fig
|
|
|
|
|
def _upload_basenames(files):
    """Join the original base names of uploaded files with ";" (fills the name box)."""
    return ";".join(os.path.basename(f.orig_name) for f in files)


with gr.Blocks() as demo:
    gr.Markdown(
        """# SatSeg
Train models and run inference for segmentation of multispectral satellite images."""
    )

    # Read each registry once; these snapshots seed the dropdowns and tables
    # below. Event handlers push refreshed tables/choices back to them later.
    model_df = pd.read_csv(MODEL_INFO_PATH)
    dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH)

    with gr.Tab("Train"):
        train_tif_names = gr.Dropdown(
            label="TIF Files",
            choices=dataset_tif_df["Name"].to_list(),
            multiselect=True,
        )
        train_mask_names = gr.Dropdown(
            label="Mask files",
            choices=dataset_mask_df["Name"].to_list(),
            multiselect=True,
        )
        train_rs_index = gr.Textbox(
            label="Remote Sensing Index", placeholder="( c0 + c1 ) / ( c0 - c1 ) ="
        )
        train_model_name = gr.Textbox(
            label="Model Name", placeholder="Give the model a name"
        )
        train_button = gr.Button("Train")
        train_completion = gr.Text(label="Training Status", value="Not Started")

    with gr.Tab("Infer"):
        infer_tif_names = gr.Dropdown(
            label="TIF Files",
            choices=dataset_tif_df["Name"].to_list(),
            multiselect=True,
        )
        infer_model_name = gr.Dropdown(
            label="Model Name",
            choices=model_df["Name"].to_list(),
        )
        infer_button = gr.Button("Infer")
        infer_mask = gr.Files(label="Output Shapefile", interactive=False)

    with gr.Tab("Datasets"):
        # The registries were already read above, so there is no second read here.
        datasets_upload_tif = gr.File(label="Images (.tif)", file_count="multiple")
        datasets_upload_tif_name = gr.Textbox(
            label="TIF name", placeholder="tif_file_1.tif;tif_file_2.tif"
        )
        datasets_save_uploaded_tif = gr.Button("Save")

        datasets_upload_mask = gr.File(
            label="Masks (Please upload all extensions (.shp, .shx, etc.))",
            file_count="multiple",
        )
        datasets_upload_mask_name = gr.Textbox(
            label="Mask name", placeholder="mask_1.shp;mask_1.shx"
        )
        datasets_mask_class_name = gr.Textbox(
            label="Class (The name of the object you want to segment)"
        )
        datasets_save_uploaded_mask = gr.Button("Save")

        datasets_tif_table = gr.Dataframe(dataset_tif_df, label="TIFs")
        datasets_mask_table = gr.Dataframe(dataset_mask_df, label="Masks")

    with gr.Tab("Models"):
        models_table = gr.Dataframe(model_df)

    # Input order must match gr_train_model(tif_names, mask_names, model_name, expression).
    train_button.click(
        gr_train_model,
        inputs=[
            train_tif_names,
            train_mask_names,
            train_model_name,
            train_rs_index,
        ],
        outputs=[train_completion, models_table, infer_model_name],
    )
    infer_button.click(
        gr_run_inference,
        inputs=[infer_tif_names, infer_model_name],
        outputs=[infer_mask],
    )

    # Pre-fill the editable name boxes with the uploaded files' base names.
    datasets_upload_tif.upload(
        _upload_basenames,
        inputs=datasets_upload_tif,
        outputs=datasets_upload_tif_name,
    )
    datasets_upload_mask.upload(
        _upload_basenames,
        inputs=datasets_upload_mask,
        outputs=datasets_upload_mask_name,
    )

    datasets_save_uploaded_tif.click(
        gr_save_tif_file,
        inputs=[datasets_upload_tif, datasets_upload_tif_name],
        outputs=[datasets_tif_table, train_tif_names, infer_tif_names],
    )
    # NOTE(review): gr_save_mask_file returns three values but only two
    # outputs are wired; Gradio 3 ignores the surplus value.
    datasets_save_uploaded_mask.click(
        gr_save_mask_file,
        inputs=[
            datasets_upload_mask,
            datasets_upload_mask_name,
            datasets_mask_class_name,
        ],
        outputs=[datasets_mask_table, train_mask_names],
    )

demo.queue(concurrency_count=10).launch(debug=True)
|
|