# DeepSlide: src/streamlit_app.py
# Author: harshinde; last commit 20acb0b ("Update src/streamlit_app.py")
import streamlit as st
import sys
import os
# Add the parent directory to sys.path to allow imports from 'src'
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
import os
import io
# Import models
from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
from src.vgg16_model import LandslideModel as VGG16Model
from src.resnet34_model import LandslideModel as ResNet34Model
from src.efficientnetb0_model import LandslideModel as EfficientNetB0Model
from src.mitb1_model import LandslideModel as MiTB1Model
from src.inceptionv4_model import LandslideModel as InceptionV4Model
from src.densenet121_model import LandslideModel as DenseNet121Model
from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
from src.resnext50_32x4d_model import LandslideModel as ResNeXt50Model
from src.se_resnet50_model import LandslideModel as SEResNet50Model
from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50Model
from src.segformer_model import LandslideModel as SegFormerB2Model
from src.inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
from src.model_downloader import ModelDownloader
# Registry of selectable models, keyed by the identifier used in the UI.
#   name           - display name; process_and_visualize also derives the
#                    imported class name from it (name + "Model").
#   type           - value written into config["model_config"]["model_type"].
#   downloader_key - optional key passed to ModelDownloader when it differs
#                    from the registry key.
AVAILABLE_MODELS = {
    "mobilenetv2": dict(name="MobileNetV2", type="mobilenet_v2"),
    "vgg16": dict(name="VGG16", type="vgg16"),
    "resnet34": dict(name="ResNet34", type="resnet34"),
    "efficientnetb0": dict(name="EfficientNetB0", type="efficientnet_b0"),
    "mitb1": dict(name="MiTB1", type="mitb1"),
    "inceptionv4": dict(name="InceptionV4", type="inception_v4"),
    "densenet121": dict(name="DenseNet121", type="densenet121"),
    "deeplabv3plus": dict(name="DeepLabV3Plus", type="deeplabv3plus"),
    "resnext50": dict(
        name="ResNeXt50", type="resnext50_32x4d", downloader_key="resnext50_32x4d"
    ),
    "seresnet50": dict(
        name="SEResNet50", type="se_resnet50", downloader_key="se_resnet50"
    ),
    "seresnext50": dict(
        name="SEResNeXt50", type="se_resnext50_32x4d", downloader_key="se_resnext50_32x4d"
    ),
    "segformerb2": dict(
        name="SegFormerB2", type="segformer_b2", downloader_key="segformer"
    ),
    "inceptionresnetv2": dict(name="InceptionResNetV2", type="inception_resnet_v2"),
}
def _describe_model(key, entry):
    """Build the sidebar record for one AVAILABLE_MODELS entry.

    The downloader key falls back to the registry key when the entry does
    not define one.
    """
    display_name = entry["name"]
    return {
        "type": entry["type"],
        "description": f"{display_name} - A model for landslide detection and segmentation.",
        "name": display_name,
        "downloader_key": entry.get("downloader_key", key),
    }


# Per-model metadata shown in the sidebar and passed to process_and_visualize.
MODEL_DESCRIPTIONS = {
    key: _describe_model(key, entry) for key, entry in AVAILABLE_MODELS.items()
}
# Inference configuration shared by every model.
# Defined as a Python literal instead of an embedded YAML string. YAML
# nesting is carried purely by indentation: if the string is flattened, it
# parses into one top-level mapping in which "model_config" and
# "dataset_config" are None. That breaks the lookups this file does
# (config["model_config"]["model_type"], config["dataset_config"]["channels"]).
# NOTE(review): keys other than those two are passed through to the model
# classes unread here; their section placement follows the natural reading
# of the original YAML. Confirm against the training config.
import json

config = {
    "model_config": {
        "model_type": "mobilenet_v2",
        "in_channels": 14,
        "num_classes": 1,
        "encoder_weights": "imagenet",
        "wce_weight": 0.5,
    },
    "dataset_config": {
        "num_classes": 1,
        "num_channels": 14,
        # 1-based band indices copied from each input file, in this order.
        "channels": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        "normalize": False,
    },
    "train_config": {
        "dataset_path": "",
        "checkpoint_path": "checkpoints",
        "seed": 42,
        "train_val_split": 0.8,
        "batch_size": 16,
        "num_epochs": 100,
        "lr": 0.001,
        "device": "cuda:0",
        "save_config": True,
        "experiment_name": "mobilenet_v2",
    },
    "logging_config": {
        "wandb_project": "l4s",
        "wandb_entity": "Silvamillion",
    },
}

# Textual form of the same configuration, kept for backward compatibility
# with the former module-level name. JSON syntax is also accepted by YAML
# parsers, so yaml.safe_load(config_str) reproduces `config`.
config_str = json.dumps(config, indent=2)
def process_and_visualize(model_key, model_info, image_tensor, original_image, uploaded_file_name):
    """
    Run one model on one image and render the results in the Streamlit page.

    Args:
        model_key: Key into AVAILABLE_MODELS (e.g. "vgg16"). Also used in the
            name of the downloadable prediction file.
        model_info: Entry from MODEL_DESCRIPTIONS. "name" and "type" are read;
            "downloader_key" is used when present and falls back to model_key.
        image_tensor: Model input. The callers in this file build it as a
            (1, C, H, W) float tensor.
        original_image: The same image as a (C, H, W) numpy array, used only
            for display.
        uploaded_file_name: Name of the source file. Its extension is dropped
            when naming the downloadable prediction.

    Any exception is caught and shown in the page (message plus traceback).
    Nothing is raised to the caller.
    """
    import copy
    import traceback

    try:
        st.write(f"Using model: {model_info['name']}")

        # Deep copy: config.copy() would share the nested 'model_config' dict,
        # so the assignment below would silently rewrite the global config.
        current_config = copy.deepcopy(config)
        current_config['model_config']['model_type'] = model_info['type']

        # Model classes are imported at the top of the file as <Name>Model,
        # e.g. "MobileNetV2" -> MobileNetV2Model.
        model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
        model_class = globals().get(model_class_name)
        if model_class is None:
            raise KeyError(
                f"No model class named '{model_class_name}' is imported for model '{model_key}'"
            )

        # Obtain the checkpoint path from ModelDownloader.
        downloader = ModelDownloader()
        download_key = model_info.get('downloader_key', model_key)
        model_path = downloader.download_model(download_key)
        st.info(f"Using model from: {model_path}")

        model = model_class(current_config)
        # NOTE(review): torch.load unpickles the file, so the checkpoint must
        # come from a trusted source.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        # strict=False tolerates key mismatches. Report them instead of
        # silently running a partially initialised model.
        load_result = model.load_state_dict(state_dict, strict=False)
        missing = getattr(load_result, 'missing_keys', [])
        unexpected = getattr(load_result, 'unexpected_keys', [])
        if missing or unexpected:
            st.warning(
                f"{model_info['name']}: checkpoint did not match the model exactly "
                f"({len(missing)} missing, {len(unexpected)} unexpected keys)."
            )
        model.eval()

        with torch.no_grad():
            raw_output = model(image_tensor)
            prediction = torch.sigmoid(raw_output).cpu().numpy()
        prob_map = prediction.squeeze()

        st.header(f"Prediction Results - {model_info['name']}")
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        try:
            img_display = original_image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
            # Min-max scale to [0, 1] over all channels. A constant image
            # would otherwise divide by zero and produce NaNs.
            low, high = img_display.min(), img_display.max()
            if high > low:
                img_display = (img_display - low) / (high - low)
            else:
                img_display = np.zeros_like(img_display)
            rgb = img_display[:, :, :3]  # first three channels shown as RGB

            ax[0].imshow(rgb)
            ax[0].set_title("Input Image")
            ax[0].axis('off')

            ax[1].imshow(prob_map, cmap='plasma')  # raw probability map
            ax[1].set_title("Prediction Probability")
            ax[1].axis('off')

            ax[2].imshow(rgb)
            ax[2].imshow(prob_map > 0.5, cmap='plasma', alpha=0.4)  # thresholded overlay
            ax[2].set_title("Overlay (Threshold > 0.5)")
            ax[2].axis('off')

            st.pyplot(fig)
        finally:
            # Always release the figure, even if drawing fails.
            plt.close(fig)

        st.write(f"Download the prediction as a .npy file for {model_info['name']}:")
        # np.save writes the .npy header (dtype and shape). Raw tobytes()
        # output saved under a .npy name cannot be read back with np.load.
        npy_buffer = io.BytesIO()
        np.save(npy_buffer, prob_map)
        base_name = os.path.splitext(uploaded_file_name)[0]
        st.download_button(
            label=f"Download Prediction - {model_info['name']}",
            data=npy_buffer.getvalue(),
            file_name=f"{base_name}_{model_key}_prediction.npy",
            mime="application/octet-stream"
        )
    except Exception as e:
        st.error(f"Error with model {model_info['name']}: {str(e)}")
        st.error(traceback.format_exc())
# ---------------------------------------------------------------------------
# Page header and usage instructions
# ---------------------------------------------------------------------------
st.set_page_config(page_title="DeepSlide: Landslide Detection", layout="wide")
st.title("DeepSlide: Landslide Detection")
st.markdown("""
## Instructions
1. **Model Selection**: Choose a single model from the sidebar or select "Run all models".
2. **Data Input**:
- Try an example image from the dropdown, or
- Upload your own .h5 files
3. **Results**: View predictions and download results as .npy files.
""")

# ---------------------------------------------------------------------------
# Sidebar: choose one model (and show its details) or run every model.
# selected_model_key stays None in "Run all models" mode; selected_model_info
# is only bound in single-model mode.
# ---------------------------------------------------------------------------
with st.sidebar:
    st.title("Model Selection")
    model_option = st.radio("Choose an option", ["Select a single model", "Run all models"])
    selected_model_key = None
    if model_option == "Select a single model":
        selected_model_key = st.selectbox("Select Model", list(MODEL_DESCRIPTIONS))
        selected_model_info = MODEL_DESCRIPTIONS[selected_model_key]
        st.markdown("### Model Details")
        st.markdown(f"**Model Name:** {selected_model_info['name']}")
        st.markdown(f"**Model Type:** {selected_model_info['type']}")
        st.markdown(f"**Description:** {selected_model_info['description']}")
# Main content
st.header("Upload Data")

# Initialize session state for error tracking if not exists.
# NOTE(review): upload_errors is not read anywhere in the visible code.
if 'upload_errors' not in st.session_state:
    st.session_state.upload_errors = []

# Example images: the .h5 files in the "examples" directory next to this
# file's parent directory, listed in sorted order.
st.subheader("Try Example Images")
examples_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples")
example_files = []
if os.path.exists(examples_dir):
    try:
        example_files = sorted(f for f in os.listdir(examples_dir) if f.endswith('.h5'))
    except OSError as err:
        # Best effort: uploads still work without examples, but tell the
        # user why they are missing instead of hiding the error.
        st.warning(f"Could not list example files in {examples_dir}: {err}")

if example_files:
    selected_example = st.selectbox(
        "Select an example image to test:",
        options=["None"] + example_files,
        help="Choose an example .h5 file to quickly test the models"
    )
else:
    st.info("No example files found")
    selected_example = "None"

# File upload section
st.subheader("Upload Your Own Files")
uploaded_files = st.file_uploader(
    "Choose .h5 files...",
    type="h5",
    accept_multiple_files=True,
    help="Upload your .h5 files here. Maximum file size is 200MB."
)
def process_h5_file(file_path, file_name, num_bands=14):
    """
    Load the 'img' dataset from an .h5 file and run the selected model(s) on it.

    Args:
        file_path: Anything h5py.File can open read-only, such as a filesystem
            path or a binary file-like object like io.BytesIO.
        file_name: Display name used in messages and in download file names.
        num_bands: Band count expected in the raw 'img' array. It is used only
            to decide whether the array is (C, H, W) or (H, W, C). Defaults to
            the 14 previously hard-coded.

    The bands listed in config["dataset_config"]["channels"] (1-based) are
    copied, in that order, into a (len(channels), H, W) float array. H and W
    are taken from the data rather than fixed at 128. If neither the first
    nor the last axis equals num_bands, the array is assumed to be (C, H, W),
    and requested bands beyond its first axis stay zero.

    Problems are reported in the page (st.error / st.warning). Nothing is
    raised.
    """
    import traceback

    try:
        with h5py.File(file_path, 'r') as hdf:
            if 'img' not in hdf:
                st.error(f"Error: 'img' dataset not found in {file_name}")
                return
            data = np.array(hdf.get('img'))
        # Replace NaNs with a small constant before inference.
        data[np.isnan(data)] = 0.000001

        if data.ndim != 3:
            st.error(f"Data has {data.ndim} dimensions, expected 3.")
            return

        # Normalise the layout to (C, H, W) once, so band selection has a
        # single code path.
        if data.shape[0] == num_bands:
            bands_first = data
        elif data.shape[2] == num_bands:
            bands_first = data.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
        else:
            st.warning(f"Unexpected data shape: {data.shape}. Assuming (C, H, W).")
            bands_first = data

        channels = config["dataset_config"]["channels"]
        _, height, width = bands_first.shape
        image_display = np.zeros((len(channels), height, width))  # (C, H, W)
        for i, band in enumerate(channels):
            if band - 1 < bands_first.shape[0]:
                image_display[i] = bands_first[band - 1]

        image_tensor = torch.from_numpy(image_display).unsqueeze(0).float()  # (1, C, H, W)

        if model_option == "Select a single model":
            process_and_visualize(selected_model_key, selected_model_info, image_tensor, image_display, file_name)
        else:
            for model_key, model_info in MODEL_DESCRIPTIONS.items():
                process_and_visualize(model_key, model_info, image_tensor, image_display, file_name)
    except Exception as e:
        st.error(f"Error processing file {file_name}: {str(e)}")
        st.error(traceback.format_exc())
# Run the chosen example image (if any) through the shared .h5 pipeline.
if selected_example != "None":
    example_path = os.path.join(examples_dir, selected_example)
    st.write(f"Processing example: {selected_example}")
    with st.spinner(f'Processing {selected_example}...'):
        process_h5_file(example_path, selected_example)
# Process uploaded files.
# Each upload is wrapped in BytesIO (h5py.File accepts file-like objects) and
# handed to process_h5_file. Uploads and example files therefore share one
# implementation of loading, band selection, inference and error reporting,
# instead of maintaining a duplicated copy of that logic here.
if uploaded_files:
    for uploaded_file in uploaded_files:
        st.write(f"Processing file: {uploaded_file.name}")
        st.write(f"File size: {uploaded_file.size} bytes")
        with st.spinner('Processing...'):
            process_h5_file(io.BytesIO(uploaded_file.getvalue()), uploaded_file.name)
# Confirm completion once an upload or an example has been handled.
if uploaded_files or selected_example != "None":
    st.success('✅ Processing completed!')