import streamlit as st

import sys
import os

# Make the project root importable so the `src.*` modules resolve when this
# script is run from the app/ subdirectory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import copy
import io

import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import yaml

from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
from src.vgg16_model import LandslideModel as VGG16Model
from src.resnet34_model import LandslideModel as ResNet34Model
from src.efficientnetb0_model import LandslideModel as EfficientNetB0Model
from src.mitb1_model import LandslideModel as MiTB1Model
from src.inceptionv4_model import LandslideModel as InceptionV4Model
from src.densenet121_model import LandslideModel as DenseNet121Model
from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
from src.resnext50_32x4d_model import LandslideModel as ResNeXt50Model
from src.se_resnet50_model import LandslideModel as SEResNet50Model
from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50Model
from src.segformer_model import LandslideModel as SegFormerB2Model
from src.inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
from src.model_downloader import ModelDownloader

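# Registry of selectable models. Each `name` matches an import alias above
# (minus the trailing "Model"); the optional `downloader_key` overrides the
# key passed to ModelDownloader when it differs from the UI key.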
AVAILABLE_MODELS = {
    "mobilenetv2": {"name": "MobileNetV2", "type": "mobilenet_v2"},
    "vgg16": {"name": "VGG16", "type": "vgg16"},
    "resnet34": {"name": "ResNet34", "type": "resnet34"},
    "efficientnetb0": {"name": "EfficientNetB0", "type": "efficientnet_b0"},
    "mitb1": {"name": "MiTB1", "type": "mitb1"},
    "inceptionv4": {"name": "InceptionV4", "type": "inception_v4"},
    "densenet121": {"name": "DenseNet121", "type": "densenet121"},
    "deeplabv3plus": {"name": "DeepLabV3Plus", "type": "deeplabv3plus"},
    "resnext50": {"name": "ResNeXt50", "type": "resnext50_32x4d", "downloader_key": "resnext50_32x4d"},
    "seresnet50": {"name": "SEResNet50", "type": "se_resnet50", "downloader_key": "se_resnet50"},
    "seresnext50": {"name": "SEResNeXt50", "type": "se_resnext50_32x4d", "downloader_key": "se_resnext50_32x4d"},
    "segformerb2": {"name": "SegFormerB2", "type": "segformer_b2", "downloader_key": "segformer"},
    "inceptionresnetv2": {"name": "InceptionResNetV2", "type": "inception_resnet_v2"}
}

MODEL_DESCRIPTIONS = {
    model_key: {
        "type": model_info["type"],
        "description": f"{model_info['name']} - A model for landslide detection and segmentation.",
        "name": model_info["name"],
        "downloader_key": model_info.get("downloader_key", model_key)
    }
    for model_key, model_info in AVAILABLE_MODELS.items()
}

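# Default run configuration shared by every model. `model_type` is overridden
# per model at inference time; `dataset_config.channels` lists the 1-indexed
# band numbers used when extracting channels from the .h5 files below.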
config_str = """ |
|
|
model_config: |
|
|
model_type: "mobilenet_v2" |
|
|
in_channels: 14 |
|
|
num_classes: 1 |
|
|
encoder_weights: "imagenet" |
|
|
wce_weight: 0.5 |
|
|
|
|
|
dataset_config: |
|
|
num_classes: 1 |
|
|
num_channels: 14 |
|
|
channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] |
|
|
normalize: False |
|
|
|
|
|
train_config: |
|
|
dataset_path: "" |
|
|
checkpoint_path: "checkpoints" |
|
|
seed: 42 |
|
|
train_val_split: 0.8 |
|
|
batch_size: 16 |
|
|
num_epochs: 100 |
|
|
lr: 0.001 |
|
|
device: "cuda:0" |
|
|
save_config: True |
|
|
experiment_name: "mobilenet_v2" |
|
|
|
|
|
logging_config: |
|
|
wandb_project: "l4s" |
|
|
wandb_entity: "Silvamillion" |
|
|
""" |
|
|
|
|
|
config = yaml.safe_load(config_str) |
|
|
|
|
|
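# The parsed dict is consumed below: process_and_visualize() deep-copies it
# and swaps `model_config.model_type`, and process_h5_file() reads
# `dataset_config.channels` for band selection.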
def process_and_visualize(model_key, model_info, image_tensor, original_image, uploaded_file_name):
    """
    Process the image with the selected model and visualize results.
    """
    try:
        st.write(f"Using model: {model_info['name']}")

        # Deep-copy the default config so per-model edits don't leak into
        # the shared dict, then swap in this model's type.
        current_config = copy.deepcopy(config)
        current_config['model_config']['model_type'] = model_info['type']

        # Resolve the model class from the import aliases above,
        # e.g. "MobileNetV2" -> MobileNetV2Model.
        model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
        model_class = globals().get(model_class_name)
        if model_class is None:
            st.error(f"Model class {model_class_name} is not available.")
            return

        # Fetch (or reuse a cached copy of) the checkpoint.
        downloader = ModelDownloader()
        download_key = model_info.get('downloader_key', model_key)
        model_path = downloader.download_model(download_key)
        st.info(f"Using model from: {model_path}")

        model = model_class(current_config)
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
        model.eval()

        with torch.no_grad():
            prediction = model(image_tensor)
            prediction = torch.sigmoid(prediction).cpu().numpy()

        st.header(f"Prediction Results - {model_info['name']}")
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))

        # Min-max normalize for display; the epsilon guards against a
        # constant-valued tile.
        img_display = original_image.transpose(1, 2, 0)
        img_display = (img_display - img_display.min()) / (img_display.max() - img_display.min() + 1e-8)

        ax[0].imshow(img_display[:, :, :3])
        ax[0].set_title("Input Image")
        ax[0].axis('off')

        ax[1].imshow(prediction.squeeze(), cmap='plasma')
        ax[1].set_title("Prediction Probability")
        ax[1].axis('off')

        ax[2].imshow(img_display[:, :, :3])
        ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.4)
        ax[2].set_title("Overlay (Threshold > 0.5)")
        ax[2].axis('off')

        st.pyplot(fig)
        plt.close(fig)

        # Serialize with np.save so the download is a valid .npy file;
        # raw tobytes() would drop the dtype/shape header.
        st.write(f"Download the prediction as a .npy file for {model_info['name']}:")
        npy_buffer = io.BytesIO()
        np.save(npy_buffer, prediction.squeeze())
        st.download_button(
            label=f"Download Prediction - {model_info['name']}",
            data=npy_buffer.getvalue(),
            file_name=f"{uploaded_file_name.split('.')[0]}_{model_key}_prediction.npy",
            mime="application/octet-stream"
        )

    except Exception as e:
        st.error(f"Error with model {model_info['name']}: {str(e)}")
        import traceback
        st.error(traceback.format_exc())

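# Hypothetical direct invocation (outside the widget flow), assuming `x` is a
# (1, 14, 128, 128) float tensor prepared like `image_tensor` below:
#   process_and_visualize("resnet34", MODEL_DESCRIPTIONS["resnet34"],
#                         x, x[0].numpy(), "tile.h5")
# A downloaded prediction can then be reloaded with np.load("..._prediction.npy").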
st.set_page_config(page_title="DeepSlide: Landslide Detection", layout="wide")

st.title("DeepSlide: Landslide Detection")
st.markdown("""
## Instructions
1. **Model Selection**: Choose a single model from the sidebar or select "Run all models".
2. **Data Input**:
   - Try an example image from the dropdown, or
   - Upload your own .h5 files.
3. **Results**: View predictions and download results as .npy files.
""")

st.sidebar.title("Model Selection")
model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])

selected_model_key = None
selected_model_info = None
if model_option == "Select a single model":
    selected_model_key = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
    selected_model_info = MODEL_DESCRIPTIONS[selected_model_key]

    st.sidebar.markdown("### Model Details")
    st.sidebar.markdown(f"**Model Name:** {selected_model_info['name']}")
    st.sidebar.markdown(f"**Model Type:** {selected_model_info['type']}")
    st.sidebar.markdown(f"**Description:** {selected_model_info['description']}")

st.header("Upload Data") |
|
|
|
|
|
|
|
|
if 'upload_errors' not in st.session_state: |
|
|
st.session_state.upload_errors = [] |
|
|
|
|
|
|
|
|
st.subheader("Try Example Images") |
|
|
examples_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples") |
|
|
example_files = [] |
|
|
|
|
|
try: |
|
|
if os.path.exists(examples_dir): |
|
|
example_files = [f for f in os.listdir(examples_dir) if f.endswith('.h5')] |
|
|
example_files.sort() |
|
|
except: |
|
|
pass |
|
|
|
|
|
if example_files: |
|
|
selected_example = st.selectbox( |
|
|
"Select an example image to test:", |
|
|
options=["None"] + example_files, |
|
|
help="Choose an example .h5 file to quickly test the models" |
|
|
) |
|
|
else: |
|
|
st.info("No example files found") |
|
|
selected_example = "None" |
|
|
|
|
|
|
|
|
st.subheader("Upload Your Own Files") |
|
|
uploaded_files = st.file_uploader( |
|
|
"Choose .h5 files...", |
|
|
type="h5", |
|
|
accept_multiple_files=True, |
|
|
help="Upload your .h5 files here. Maximum file size is 200MB." |
|
|
) |
|
|
|
|
|
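# Shared loader: reads the 'img' dataset, builds a (1, C, 128, 128) float
# tensor, and dispatches to the selected model(s).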
def process_h5_file(file_path, file_name):
    """Process a single .h5 file; accepts a path or a file-like object."""
    try:
        with h5py.File(file_path, 'r') as hdf:
            if 'img' not in hdf:
                st.error(f"Error: 'img' dataset not found in {file_name}")
                return

            data = np.array(hdf.get('img'))
            data[np.isnan(data)] = 0.000001
            channels = config["dataset_config"]["channels"]

            image = np.zeros((128, 128, len(channels)))

            # Band numbers in the config are 1-indexed; handle both
            # channels-first (C, H, W) and channels-last (H, W, C) layouts.
            if data.ndim == 3:
                if data.shape[0] == 14:
                    for i, band in enumerate(channels):
                        image[:, :, i] = data[band-1, :, :]
                elif data.shape[2] == 14:
                    for i, band in enumerate(channels):
                        image[:, :, i] = data[:, :, band-1]
                else:
                    st.warning(f"Unexpected data shape: {data.shape}. Assuming (C, H, W).")
                    for i, band in enumerate(channels):
                        if band-1 < data.shape[0]:
                            image[:, :, i] = data[band-1, :, :]
            else:
                st.error(f"Data has {data.ndim} dimensions, expected 3.")
                return

        image_display = image.transpose(2, 0, 1)
        image_tensor = torch.from_numpy(image_display).unsqueeze(0).float()

        if model_option == "Select a single model":
            process_and_visualize(selected_model_key, selected_model_info, image_tensor, image_display, file_name)
        else:
            for model_key, model_info in MODEL_DESCRIPTIONS.items():
                process_and_visualize(model_key, model_info, image_tensor, image_display, file_name)

    except Exception as e:
        st.error(f"Error processing file {file_name}: {str(e)}")
        import traceback
        st.error(traceback.format_exc())

if selected_example != "None":
    st.write(f"Processing example: {selected_example}")
    example_path = os.path.join(examples_dir, selected_example)
    with st.spinner(f'Processing {selected_example}...'):
        process_h5_file(example_path, selected_example)

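# h5py.File also accepts Python file-like objects, so uploads can be wrapped
# in io.BytesIO and routed through the same helper as the on-disk examples.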
if uploaded_files:
    for uploaded_file in uploaded_files:
        st.write(f"Processing file: {uploaded_file.name}")
        st.write(f"File size: {uploaded_file.size} bytes")

        with st.spinner('Processing...'):
            process_h5_file(io.BytesIO(uploaded_file.getvalue()), uploaded_file.name)

if selected_example != "None" or uploaded_files:
    st.success('✅ Processing completed!')