"""SatSeg: a Gradio app for training segmentation models on multispectral
satellite imagery and running inference to produce shapefile masks."""
import os
import shutil
import tempfile
import gradio as gr
import plotly.graph_objects as go
import pandas as pd
from time import time
from utils import (
create_file_structure,
init_info_csv,
add_to_info_csv,
)
from satseg.dataset import create_datasets, create_inference_dataset
from satseg.model import train_model, save_model, run_inference, load_model
from satseg.seg_result import combine_seg_maps, get_combined_map_contours
from satseg.geo_tools import (
shapefile_to_latlong,
shapefile_to_grid_indices,
points_to_shapefile,
contours_to_shapefile,
get_tif_n_channels,
)
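
# All app state (uploaded rasters and masks, trained models, registry CSVs)
# is kept under a local "data" directory.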
DATA_DIR = "data"
MODEL_DIR = os.path.join(DATA_DIR, "models")
TIF_DIR = os.path.join(DATA_DIR, "tifs")
MASK_DIR = os.path.join(DATA_DIR, "masks")
INFO_DIR = os.path.join(DATA_DIR, "info")
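
# Registry CSVs; the UI tables and dropdown choices are populated from these.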
MODEL_INFO_PATH = os.path.join(INFO_DIR, "model_data.csv")
DATASET_TIF_INFO_PATH = os.path.join(INFO_DIR, "dataset_tif_data.csv")
DATASET_MASK_INFO_PATH = os.path.join(INFO_DIR, "dataset_mask_data.csv")
create_file_structure(
    [DATA_DIR, MODEL_DIR, TIF_DIR, MASK_DIR, INFO_DIR],
[MODEL_INFO_PATH, DATASET_TIF_INFO_PATH, DATASET_MASK_INFO_PATH],
)
init_info_csv(
MODEL_INFO_PATH,
[
"Name",
"Architecture",
"# of channels",
"Train TIF",
"Train Mask",
"Expression",
"Path",
],
)
init_info_csv(DATASET_TIF_INFO_PATH, ["Name", "# of channels", "Path"])
init_info_csv(DATASET_MASK_INFO_PATH, ["Name", "Class", "Path"])
def gr_train_model(
tif_names, mask_names, model_name, expression, progress=gr.Progress()
):
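    """Train a UNet on the selected TIF/mask pairs and register it in the model CSV.

    `expression` is a whitespace-separated remote-sensing-index formula over the
    raster channels (e.g. "( c6 - c0 ) / ( c6 + c0 ) ="), passed through to the
    dataset builder. Returns the status text, the refreshed model table, and an
    update for the inference model dropdown.
    """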
    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]
    mask_paths = [os.path.join(MASK_DIR, name) for name in mask_names]
expression = expression.strip().split()
# if arch.lower() == "best":
# arch = "dcama" if len(train_set) > 8 and len(train_set) < 20 else "unet"
    # Example index (normalized difference of channels 6 and 0): ( c6 - c0 ) / ( c6 + c0 ) =
progress(0, desc="Creating Dataset...")
with tempfile.TemporaryDirectory() as tempdir:
train_set, val_set = create_datasets(
tif_paths, mask_paths, tempdir, expression=expression
)
progress(0.05, desc="Training Model...")
model, _ = train_model(train_set, val_set, "unet")
progress(0.95, desc="Model Trained! Saving...")
model_name = "_".join(model_name.split()) + ".pt"
model_path = os.path.join(MODEL_DIR, model_name)
save_model(model, model_path)
add_to_info_csv(
MODEL_INFO_PATH,
[
model_name,
"UNet",
val_set.n_channels,
";".join(tif_names),
";".join(mask_names),
" ".join(expression),
model_path,
],
)
progress(1.0, desc="Done!")
model_df = pd.read_csv(MODEL_INFO_PATH)
return "Done!", model_df, gr.Dropdown.update(choices=model_df["Name"].to_list())
def gr_run_inference(tif_names, model_name, progress=gr.Progress()):
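    """Build an inference dataset from the selected TIFs, run the chosen model,
    and return paths to zipped output shapefiles (one per input TIF)."""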
t = time()
    tif_paths = [os.path.join(TIF_DIR, name) for name in tif_names]
model_df = pd.read_csv(MODEL_INFO_PATH, index_col="Name")
model_path = model_df["Path"][model_name]
with tempfile.TemporaryDirectory() as tempdir:
progress(0, desc="Creating Dataset...")
dataset = create_inference_dataset(
tif_paths,
tempdir,
            256,  # patch size in pixels (assumed to match the training tiling)
expression=model_df["Expression"][model_name].split(),
)
progress(0.1, desc="Loading Model...")
model = load_model(model_path)
result_dir = os.path.join(tempdir, "infer")
comb_result_dir = os.path.join(tempdir, "comb")
os.makedirs(result_dir)
os.makedirs(comb_result_dir)
progress(0.2, desc="Running Inference...")
run_inference(dataset, model, result_dir)
progress(0.8, desc="Preparing output...")
combine_seg_maps(result_dir, comb_result_dir)
results = get_combined_map_contours(comb_result_dir)
file_paths = []
out_dir = os.path.join(MASK_DIR, "output")
if os.path.exists(out_dir):
shutil.rmtree(out_dir)
os.makedirs(out_dir)
for tif_name, (contours, hierarchy) in results.items():
tif_path = os.path.join(TIF_DIR, f"{tif_name}.tif")
mask_path = os.path.join(out_dir, f"{tif_name}_mask.shp")
zip_path = contours_to_shapefile(contours, hierarchy, tif_path, mask_path)
file_paths.append(zip_path)
print(time() - t, "seconds")
return file_paths
def gr_save_mask_file(file_objs, filenames, obj_class):
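    """Move uploaded mask files into MASK_DIR; `filenames` is a ";"-separated
    list matching `file_objs`. Each .shp is registered under `obj_class` in the
    mask registry CSV."""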
print("Saving file(s)...")
idx = 0
for filename in filenames.split(";"):
if filename.strip() == "":
continue
filepath = os.path.join(MASK_DIR, filename.strip())
obj = file_objs[idx]
idx += 1
shutil.move(obj.name, filepath)
if filename.endswith(".shp"):
add_to_info_csv(DATASET_MASK_INFO_PATH, [filename, obj_class, filepath])
print("Done!")
    dataset_df = pd.read_csv(DATASET_MASK_INFO_PATH)
    choices = dataset_df["Name"].to_list()
    update = gr.Dropdown.update(choices=choices)
    # Two outputs are wired to this handler: the mask table and the train-mask dropdown.
    return dataset_df, update
def gr_save_tif_file(file_objs, filenames):
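    """Copy uploaded rasters into TIF_DIR and register each one, with its
    channel count, in the TIF registry CSV."""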
print("Saving file(s)...")
idx = 0
for filename in filenames.split(";"):
if filename.strip() == "":
continue
filepath = os.path.join(TIF_DIR, filename.strip())
obj = file_objs[idx]
idx += 1
shutil.copy2(obj.name, filepath)
n = get_tif_n_channels(filepath)
add_to_info_csv(DATASET_TIF_INFO_PATH, [filename, n, filepath])
print("Done!")
dataset_df = pd.read_csv(DATASET_TIF_INFO_PATH)
    choices = dataset_df["Name"].to_list()
update = gr.Dropdown.update(choices=choices)
return dataset_df, update, update
def gr_generate_map(mask_name: str, token: str = "", show_grid=True, show_mask=False):
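    """Plot the mask's sampling grid and/or contours on a map: Mapbox satellite
    imagery when a token is supplied, OpenStreetMap otherwise."""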
mask_path = os.path.join(MASK_DIR, mask_name)
    center = (7.753769, 80.691730)  # default map center (central Sri Lanka)
scattermaps = []
if show_grid:
indices = shapefile_to_grid_indices(mask_path)
points_to_shapefile(indices, mask_path[: -len(".shp")] + "-grid.shp")
scattermaps.append(
go.Scattermapbox(
lat=indices[:, 1],
lon=indices[:, 0],
mode="markers",
marker=go.scattermapbox.Marker(size=6),
)
)
if show_mask:
contours = shapefile_to_latlong(mask_path)
        for contour in contours:
lons = contour[:, 0]
lats = contour[:, 1]
scattermaps.append(
go.Scattermapbox(
fill="toself",
lat=lats,
lon=lons,
mode="markers",
marker=go.scattermapbox.Marker(size=6),
)
)
fig = go.Figure(scattermaps)
if token:
fig.update_layout(
mapbox=dict(
style="satellite-streets",
accesstoken=token,
center=go.layout.mapbox.Center(lat=center[0], lon=center[1]),
pitch=0,
zoom=7,
),
mapbox_layers=[
{
# "below": "traces",
"sourcetype": "raster",
"sourceattribution": "United States Geological Survey",
"source": [
"https://basemap.nationalmap.gov/arcgis/rest/services/USGSImageryOnly/MapServer/tile/{z}/{y}/{x}"
],
}
],
)
else:
fig.update_layout(
mapbox_style="open-street-map",
hovermode="closest",
mapbox=dict(
bearing=0,
center=go.layout.mapbox.Center(lat=center[0], lon=center[1]),
pitch=0,
zoom=7,
),
)
return fig
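
# Gradio UI: Train, Infer, Datasets, and Models tabs.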
with gr.Blocks() as demo:
gr.Markdown(
"""# SatSeg
Train models and run inference for segmentation of multispectral satellite images."""
)
model_df = pd.read_csv(MODEL_INFO_PATH)
dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH)
dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH)
with gr.Tab("Train"):
train_tif_names = gr.Dropdown(
label="TIF Files",
choices=dataset_tif_df["Name"].to_list(),
multiselect=True,
)
train_mask_names = gr.Dropdown(
label="Mask files",
choices=dataset_mask_df["Name"].to_list(),
multiselect=True,
)
train_rs_index = gr.Textbox(
label="Remote Sensing Index", placeholder="( c0 + c1 ) / ( c0 - c1 ) ="
)
# train_arch = gr.Dropdown(
# label="Model Architecture", choices=["Best", "UNet", "DCAMA"], value="Best"
# )
train_model_name = gr.Textbox(
label="Model Name", placeholder="Give the model a name"
)
train_button = gr.Button("Train")
train_completion = gr.Text(label="Training Status", value="Not Started")
with gr.Tab("Infer"):
infer_tif_names = gr.Dropdown(
label="TIF Files",
choices=dataset_tif_df["Name"].to_list(),
multiselect=True,
)
infer_model_name = gr.Dropdown(
label="Model Name",
choices=model_df["Name"].to_list(),
)
infer_button = gr.Button("Infer")
infer_mask = gr.Files(label="Output Shapefile", interactive=False)
# with gr.Tab("Sampling"):
# grid_mask_name = gr.Dropdown(
# label="Mask",
# choices=dataset_mask_df["Name"].to_list(),
# )
# grid_token = gr.Textbox(
# value="", label="Mapbox Token (https://account.mapbox.com/)"
# )
# grid_side_len = gr.Textbox(value="100", label="Sampling Gap (m)")
# grid_show_grid = gr.Checkbox(True, label="Show Grid")
# grid_show_mask = gr.Checkbox(False, label="Show Mask")
# grid_button = gr.Button("Generate Grid")
# grid_map = gr.Plot(label="Plot")
with gr.Tab("Datasets"):
dataset_tif_df = pd.read_csv(DATASET_TIF_INFO_PATH)
dataset_mask_df = pd.read_csv(DATASET_MASK_INFO_PATH)
datasets_upload_tif = gr.File(label="Images (.tif)", file_count="multiple")
datasets_upload_tif_name = gr.Textbox(
label="TIF name", placeholder="tif_file_1.tif;tif_file_2.tif"
)
datasets_save_uploaded_tif = gr.Button("Save")
datasets_upload_mask = gr.File(
label="Masks (Please upload all extensions (.shp, .shx, etc.))",
file_count="multiple",
)
datasets_upload_mask_name = gr.Textbox(
label="Mask name", placeholder="mask_1.shp;mask_1.shx"
)
datasets_mask_class_name = gr.Textbox(
label="Class (The name of the object you want to segment)"
)
datasets_save_uploaded_mask = gr.Button("Save")
datasets_tif_table = gr.Dataframe(dataset_tif_df, label="TIFs")
datasets_mask_table = gr.Dataframe(dataset_mask_df, label="Masks")
with gr.Tab("Models"):
models_table = gr.Dataframe(model_df)
train_button.click(
gr_train_model,
inputs=[
train_tif_names,
train_mask_names,
# train_arch,
train_model_name,
train_rs_index,
],
outputs=[train_completion, models_table, infer_model_name],
)
infer_button.click(
gr_run_inference,
inputs=[infer_tif_names, infer_model_name],
outputs=[infer_mask],
)
datasets_upload_tif.upload(
        lambda files: ";".join(os.path.basename(f.orig_name) for f in files),
inputs=datasets_upload_tif,
outputs=datasets_upload_tif_name,
)
datasets_upload_mask.upload(
        lambda files: ";".join(os.path.basename(f.orig_name) for f in files),
inputs=datasets_upload_mask,
outputs=datasets_upload_mask_name,
)
# grid_button.click(
# gr_generate_map,
# inputs=[grid_mask_name, grid_token, grid_show_grid, grid_show_mask],
# outputs=grid_map,
# )
datasets_save_uploaded_tif.click(
gr_save_tif_file,
inputs=[datasets_upload_tif, datasets_upload_tif_name],
outputs=[datasets_tif_table, train_tif_names, infer_tif_names],
)
datasets_save_uploaded_mask.click(
gr_save_mask_file,
inputs=[
datasets_upload_mask,
datasets_upload_mask_name,
datasets_mask_class_name,
],
outputs=[datasets_mask_table, train_mask_names],
)
demo.queue(concurrency_count=10).launch(debug=True)