| import streamlit as st |
| import tensorflow as tf |
| from tensorflow.keras.models import Model |
| from tensorflow.keras.layers import * |
| from tensorflow.keras.optimizers import Adam |
| import cv2 |
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import io |
| import base64 |
| import tempfile |
| import zipfile |
| import random |
| import time |
| import rasterio |
| from rasterio.errors import RasterioIOError |
| import h5py |
| import json |
|
|
| |
# Page-level Streamlit configuration: browser-tab title and icon, wide layout.
st.set_page_config(layout="wide", page_title="SAR Image Colorization", page_icon="🛰")
|
|
|
|
def display_image(image_path):
    """Display an image file in the Streamlit page.

    ``.tif``/``.tiff`` files are read with rasterio so multi-band and
    non-8-bit rasters can be shown; if rasterio cannot handle the file, PIL is
    tried instead.  Any other extension is opened directly with PIL.

    Args:
        image_path: Path to the image on disk.  A missing file produces an
            ``st.info`` message; read failures produce an ``st.error``.
    """
    try:
        if not os.path.exists(image_path):
            st.info(f"Image file not found: {image_path}")
            return

        if not image_path.lower().endswith(('.tif', '.tiff')):
            # Context manager closes the underlying file handle.
            with Image.open(image_path) as img:
                st.image(img, use_container_width=True)
            return

        try:
            with rasterio.open(image_path) as src:
                if src.count >= 3:
                    # Use the first three bands as R, G, B.
                    img_data = np.dstack([src.read(i) for i in range(1, 4)])
                elif src.count == 2:
                    # Two bands: repeat band 2 to fill the third channel.
                    img_data = np.dstack([src.read(1), src.read(2), src.read(2)])
                else:
                    # Single band: replicate into a grey RGB image.
                    band = src.read(1)
                    img_data = np.dstack([band, band, band])

                if img_data.dtype != np.uint8:
                    img_data = img_data.astype(np.float64)
                    finite = np.isfinite(img_data)
                    if finite.any():
                        lo = img_data[finite].min()
                        hi = img_data[finite].max()
                    else:
                        lo = hi = 0.0
                    span = hi - lo
                    # A constant (or all-NaN) raster has zero span; scaling by it
                    # would yield NaN, so show it as black instead.
                    if span > 0:
                        scaled = (img_data - lo) / span * 255
                    else:
                        scaled = np.zeros_like(img_data)
                    img_data = np.nan_to_num(np.clip(scaled, 0, 255)).astype(np.uint8)

                st.image(img_data, use_container_width=True)
        except Exception:
            # rasterio could not read/display it: fall back to PIL.
            try:
                with Image.open(image_path) as img:
                    st.image(img, use_container_width=True)
            except Exception as pil_error:
                st.error(f"Failed to load image: {str(pil_error)}")
    except Exception as e:
        st.error(f"Error loading image: {str(e)}")
|
|
| |
|
|
| |
@st.cache_resource
def setup_gpu():
    """Enable incremental GPU memory allocation for every visible GPU.

    The result is cached by ``st.cache_resource``, so the body normally runs
    only once.

    Returns:
        A human-readable status string.
    """
    devices = tf.config.experimental.list_physical_devices('GPU')
    if not devices:
        return "No GPUs found. Running on CPU."
    # Memory growth makes TensorFlow allocate GPU memory on demand rather
    # than reserving all of it up front.
    for device in devices:
        tf.config.experimental.set_memory_growth(device, True)
    return f"GPU setup complete. Found {len(devices)} GPU(s)."
|
|
| |
def get_esa_colors():
    """Return the RGB palette used to paint predicted class indices.

    Keys are class indices 0-10 and values are ``[R, G, B]`` lists in 0-255.
    A new dict is built on every call, so callers may mutate the result.
    """
    palette = [
        [0, 100, 0],        # 0  Trees
        [255, 165, 0],      # 1  Shrubland
        [144, 238, 144],    # 2  Grassland
        [255, 255, 0],      # 3  Cropland
        [255, 0, 0],        # 4  Built-up
        [139, 69, 19],      # 5  Bare
        [255, 255, 255],    # 6  Snow
        [0, 0, 255],        # 7  Water
        [0, 139, 139],      # 8  Wetland
        [0, 255, 0],        # 9  Mangroves
        [220, 220, 220],    # 10 Moss
    ]
    return dict(enumerate(palette))
|
|
| |
def visualize_with_ground_truth(sar_image, ground_truth, prediction):
    """Visualize SAR image with ground truth and prediction using ESA WorldCover colors.

    Args:
        sar_image: SAR array indexed as ``sar_image[:, :, 0]``, expected in
            [-1, 1]; values outside that range are clipped before display.
        ground_truth: Label array whose first channel holds either class
            indices 0-10 or raw codes (max > 10) that are mapped through the
            sorted values of ``st.session_state.segmentation.class_definitions``.
        prediction: Batched class-probability output; ``prediction[0]`` is
            argmax'ed over the last axis.

    Returns:
        ``(buf, colored_gt, colored_pred, overlay)``: ``buf`` is an
        ``io.BytesIO`` at position 0 holding the 1x4 figure as PNG; the other
        three are uint8 RGB arrays.
    """
    colors = get_esa_colors()

    pred_class = np.argmax(prediction[0], axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    gt_class = ground_truth[:, :, 0].astype(np.int32)
    if np.max(gt_class) > 10:
        # Raw class codes: map each known code to its index in sorted order.
        # Unknown codes (e.g. 0 / nodata) become -1, which no palette entry
        # matches, so they stay black instead of being painted as class 0.
        gt_mapped = np.full_like(gt_class, -1)
        class_values = sorted(st.session_state.segmentation.class_definitions.values())
        for i, val in enumerate(class_values):
            gt_mapped[gt_class == val] = i
        gt_class = gt_mapped

    colored_gt = np.zeros((gt_class.shape[0], gt_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_gt[gt_class == class_idx] = color

    # Map [-1, 1] -> [0, 255].  Clip first: casting out-of-range floats to
    # uint8 does not saturate.
    sar_gray = np.clip(sar_image[:, :, 0:1], -1, 1)
    sar_rgb = ((np.repeat(sar_gray, 3, axis=2) + 1) / 2 * 255).astype(np.uint8)

    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)

    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    panels = [
        (sar_rgb, 'Original SAR'),
        (colored_gt, 'Ground Truth'),
        (colored_pred, 'Prediction'),
        (overlay, 'Colorized Output'),
    ]
    for ax, (panel, title) in zip(axes, panels):
        if title == 'Original SAR':
            ax.imshow(panel, cmap='gray')
        else:
            ax.imshow(panel)
        ax.set_title(title, color=text_color)
        ax.axis('off')
        ax.set_facecolor(bg_color)

    fig.patch.set_facecolor(bg_color)
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)

    return buf, colored_gt, colored_pred, overlay
|
|
|
|
| |
@st.cache_resource
def load_models(unet_weights_path, generator_path=None):
    """Build the U-Net, load its weights, and optionally load a generator.

    Cached with ``st.cache_resource``, so repeated calls with the same paths
    reuse the already-loaded models.

    Args:
        unet_weights_path: Weights file for a (256, 256, 1)-input, 11-class U-Net.
        generator_path: Optional path to a saved Keras generator model.

    Returns:
        ``(unet, generator)``; ``generator`` is ``None`` if no path was given
        or loading failed (failures are reported with ``st.error``).
    """
    unet = get_unet(input_shape=(256, 256, 1), classes=11)
    unet.load_weights(unet_weights_path)

    if not generator_path:
        return unet, None

    try:
        generator = tf.keras.models.load_model(generator_path)
    except Exception as err:
        st.error(f"Error loading generator model: {err}")
        generator = None
    return unet, generator
|
|
| |
def preprocess_sar_for_optical(sar_data):
    """Clip SAR values to [-50, 20] and min-max scale them to [-1, 1].

    Args:
        sar_data: Numeric array of any shape.  The [-50, 20] window presumably
            refers to backscatter in dB — confirm against the data source.

    Returns:
        Float array of the same shape: the smallest clipped value maps to -1
        and the largest to 1.  If all clipped values are equal, the range is
        zero and every output is -1 (instead of NaN from 0/0).
    """
    sar_clipped = np.clip(sar_data, -50, 20)
    lo = np.min(sar_clipped)
    span = np.max(sar_clipped) - lo
    if span == 0:
        # Constant input: every numerator is 0, so any non-zero divisor
        # yields the "minimum" value -1 without producing NaN.
        span = 1
    return (sar_clipped - lo) / span * 2 - 1
|
|
| |
def load_sar_image(file, img_size=(256, 256)):
    """Read band 1 of an uploaded GeoTIFF and prepare it for the models.

    rasterio opens paths, so the upload is first written to a temporary
    ``.tif`` file; that file is always deleted afterwards.

    Args:
        file: Upload object exposing ``getbuffer()`` (e.g. a Streamlit upload).
        img_size: ``dsize`` for ``cv2.resize``, i.e. ``(width, height)``.

    Returns:
        ``(batch, image)``: ``image`` is the resized, normalised ``(H, W, 1)``
        array and ``batch`` is the same data with a leading batch axis.
        ``(None, None)`` on failure, after showing an ``st.error``.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as handle:
        handle.write(file.getbuffer())
        temp_path = handle.name

    try:
        with rasterio.open(temp_path) as dataset:
            band = dataset.read(1)
        resized = cv2.resize(band, img_size)
        normalized = preprocess_sar_for_optical(resized[..., np.newaxis])
        return normalized[np.newaxis, ...], normalized
    except Exception as err:
        st.error(f"Error loading SAR image: {err}")
        return None, None
    finally:
        os.unlink(temp_path)
|
|
| |
def process_image(sar_image, unet_model, generator_model=None):
    """Run segmentation and, when a generator is supplied, colorization.

    Args:
        sar_image: Batched input (leading batch axis) passed unchanged to
            ``unet_model.predict``.
        unet_model: Object with a Keras-style ``predict`` returning a batched
            segmentation mask.
        generator_model: Optional model whose ``predict`` receives
            ``[sar_image, seg_mask]``.

    Returns:
        ``(seg_mask[0], colorized)``: the first mask of the batch, and the
        first generated image or ``None`` when no generator was given.
    """
    seg_mask = unet_model.predict(sar_image)

    colorized = None
    # Test for presence explicitly: an object's truthiness (e.g. one defining
    # ``__len__``) says nothing about whether a generator was passed.
    if generator_model is not None:
        colorized = generator_model.predict([sar_image, seg_mask])[0]

    return seg_mask[0], colorized
|
|
| |
def visualize_results(sar_image, seg_mask, colorized=None):
    """Colour a segmentation mask and blend it over the SAR image.

    Args:
        sar_image: Unbatched SAR array indexed as ``sar_image[:, :, 0]``,
            expected in [-1, 1]; values outside are clipped.
        seg_mask: Unbatched class-probability array; argmax over the last axis
            gives class indices.
        colorized: Returned unchanged.

    Returns:
        ``(sar_rgb, colored_pred, overlay, colorized)``; the first three are
        uint8 RGB arrays, ``overlay`` being 70% SAR + 30% class colours.
    """
    colors = get_esa_colors()

    pred_class = np.argmax(seg_mask, axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    # Map [-1, 1] -> [0, 255]; clip first because casting out-of-range floats
    # to uint8 does not saturate.
    sar_gray = np.clip(sar_image[:, :, 0:1], -1, 1)
    sar_rgb = ((np.repeat(sar_gray, 3, axis=2) + 1) / 2 * 255).astype(np.uint8)

    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)

    return sar_rgb, colored_pred, overlay, colorized
|
|
|
|
| |
def load_model_with_weights(model_path):
    """Load a model directly from an H5 file, preserving the original architecture.

    Strategies, tried in order:
      1. ``tf.keras.models.load_model`` on the whole file, with the custom
         ``BilinearUpsampling`` layer registered.
      2. Rebuild the architecture from the file's ``model_config`` attribute
         and load the weights into it.
      3. Build a fresh architecture selected by
         ``st.session_state.segmentation.model_type`` and load the weights
         whose layer names match.

    A bare file name (no directory part) is looked up under ``models/``.

    Args:
        model_path: Path to the ``.h5`` file.

    Returns:
        The loaded Keras model, or ``None`` if every strategy failed.
        Progress and failures are reported with ``print``.
    """
    if not os.path.dirname(model_path):
        model_path = os.path.join('models', os.path.basename(model_path))

    # Registering the custom layer is harmless when a file does not use it,
    # so it is passed for every Keras version (previously only Keras 3).
    custom_objects = {'BilinearUpsampling': BilinearUpsampling}

    # Strategy 1: complete model (architecture + weights).
    try:
        model = tf.keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
        print("Loaded complete model with architecture")
        return model
    except Exception as e:
        print(f"Could not load complete model: {str(e)}")
        print("Attempting to load just the weights into a matching architecture...")

    # Strategy 2: architecture from the stored config, then weights.
    model_config = None
    try:
        with h5py.File(model_path, 'r') as f:
            if 'model_config' in f.attrs:
                raw_config = f.attrs['model_config']
                # h5py >= 3 returns str attributes; older versions return bytes.
                if isinstance(raw_config, bytes):
                    raw_config = raw_config.decode('utf-8')
                model_config = json.loads(raw_config)
    except Exception as e3:
        print(f"Failed to inspect model file: {str(e3)}")

    if model_config:
        try:
            model = tf.keras.models.model_from_json(json.dumps(model_config), custom_objects=custom_objects)
            model.load_weights(model_path)
            print("Successfully loaded model from config and weights")
            return model
        except Exception as e2:
            print(f"Failed to load from config: {str(e2)}")

    # Strategy 3: fresh architecture, load whatever weights match by name.
    builders = {
        'unet': lambda: get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11),
        'deeplabv3plus': lambda: DeepLabV3Plus(input_shape=(256, 256, 1), classes=11),
        'segnet': lambda: SegNet(input_shape=(256, 256, 1), classes=11),
    }
    try:
        model_type = st.session_state.segmentation.model_type
        builder = builders.get(model_type)
        if builder is None:
            print(f"Unknown model type for fallback: {model_type}")
            return None
        model = builder()
        model.load_weights(model_path, by_name=True, skip_mismatch=True)
        print("Created new model and loaded compatible weights")
        return model
    except Exception as e4:
        print(f"Failed to create new model and load weights: {str(e4)}")

    return None
|
|
| |
def create_legend():
    """Create a legend for the land cover classes.

    Swatch colours are taken from ``get_esa_colors()`` so the legend always
    matches the palette used to paint predictions; names are listed in
    class-index order.

    Returns:
        ``io.BytesIO`` at position 0 containing the legend as PNG, styled for
        ``st.session_state.theme``.
    """
    class_names = [
        'Trees', 'Shrubland', 'Grassland', 'Cropland', 'Built-up', 'Bare',
        'Snow', 'Water', 'Wetland', 'Mangroves', 'Moss',
    ]
    palette = get_esa_colors()
    colors = {name: palette[idx] for idx, name in enumerate(class_names)}

    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_facecolor(bg_color)
    ax.set_facecolor(bg_color)

    for i, (class_name, color) in enumerate(colors.items()):
        ax.add_patch(plt.Rectangle((0, i), 0.5, 0.8, color=[c / 255 for c in color]))
        ax.text(0.7, i + 0.4, class_name, color=text_color, fontsize=12)

    ax.set_xlim(0, 3)
    # Swatch i spans y in [i, i + 0.8]; the upper limit must reach past the
    # last swatch or it is clipped by the axes.
    ax.set_ylim(-0.2, len(colors))
    ax.set_title('Land Cover Classes', color=text_color, fontsize=14)
    ax.axis('off')

    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)

    return buf
|
|
|
|
| |
| |
def visualize_prediction(prediction, original_sar, figsize=(10, 4)):
    """Visualize segmentation prediction with ESA WorldCover colors.

    Args:
        prediction: Batched class-probability output; ``prediction[0]`` is
            argmax'ed over the last axis.
        original_sar: SAR array indexed as ``original_sar[:, :, 0]``.  Any dtype
            is accepted: non-uint8 data is min-max scaled to 0-255 for the
            overlay, while the first panel shows the raw band.
        figsize: Matplotlib figure size in inches.

    Returns:
        ``io.BytesIO`` at position 0 containing the 1x3 figure as PNG.
    """
    colors = get_esa_colors()

    pred_class = np.argmax(prediction[0], axis=-1)
    colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        colored_pred[pred_class == class_idx] = color

    sar_band = original_sar[:, :, 0]
    if sar_band.dtype != np.uint8:
        # cv2.cvtColor rejects e.g. float64 input and cv2.addWeighted needs
        # both inputs to share a dtype, so bring the band into uint8 first.
        band = sar_band.astype(np.float64)
        lo, hi = np.nanmin(band), np.nanmax(band)
        if hi > lo:
            band = (band - lo) / (hi - lo) * 255
        else:
            band = np.zeros_like(band)
        sar_band = np.nan_to_num(band).astype(np.uint8)
    sar_rgb = cv2.cvtColor(np.ascontiguousarray(sar_band), cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(sar_rgb, 0.7, colored_pred, 0.3, 0)

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    bg_color = '#0a0a1f' if st.session_state.theme == 'dark' else '#ffffff'
    text_color = 'white' if st.session_state.theme == 'dark' else 'black'

    axes[0].imshow(original_sar[:, :, 0], cmap='gray')
    axes[0].set_title('Original SAR', color=text_color)
    axes[0].axis('off')

    axes[1].imshow(colored_pred)
    axes[1].set_title('Prediction', color=text_color)
    axes[1].axis('off')

    axes[2].imshow(overlay)
    axes[2].set_title('Colorized Output', color=text_color)
    axes[2].axis('off')

    fig.patch.set_facecolor(bg_color)
    for ax in axes:
        ax.set_facecolor(bg_color)

    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=bg_color, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return buf
|
|
| |
|
|
| |
def get_unet(input_shape=(256, 256, 1), drop_rate=0.3, classes=11):
    """Build and compile a U-Net for per-pixel multi-class segmentation.

    Layers are created in a fixed order, and Keras derives default layer
    names from that order.  Weights are loaded into this model elsewhere in
    the module both by layer name and by layer order, so do not reorder
    layer construction.

    Args:
        input_shape: ``(height, width, channels)`` of the input.  Height and
            width should be divisible by 16 (four 2x2 poolings); otherwise
            the skip-connection concatenations will not line up.
        drop_rate: Dropout rate applied after the 4th encoder stage and after
            the bottleneck.
        classes: Number of output channels (softmax over classes).

    Returns:
        A ``Model`` compiled with Adam (learning rate 1e-4),
        ``categorical_crossentropy`` loss and ``accuracy``.
    """
    inputs = Input(input_shape)

    # Encoder stage 1: 2 x (3x3 conv 64 + BN), then 2x2 max-pool.
    conv1_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    batch1_1 = BatchNormalization()(conv1_1)
    conv1_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch1_1)
    batch1_2 = BatchNormalization()(conv1_2)
    pool1 = MaxPooling2D(pool_size=(2, 2))(batch1_2)

    # Encoder stage 2: 128 filters.
    conv2_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    batch2_1 = BatchNormalization()(conv2_1)
    conv2_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch2_1)
    batch2_2 = BatchNormalization()(conv2_2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(batch2_2)

    # Encoder stage 3: 256 filters.
    conv3_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    batch3_1 = BatchNormalization()(conv3_1)
    conv3_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch3_1)
    batch3_2 = BatchNormalization()(conv3_2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(batch3_2)

    # Encoder stage 4: 512 filters, dropout before pooling.
    conv4_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    batch4_1 = BatchNormalization()(conv4_1)
    conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch4_1)
    batch4_2 = BatchNormalization()(conv4_2)
    drop4 = Dropout(drop_rate)(batch4_2)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck: 1024 filters at 1/16 resolution, with dropout.
    conv5_1 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    batch5_1 = BatchNormalization()(conv5_1)
    conv5_2 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch5_1)
    batch5_2 = BatchNormalization()(conv5_2)
    drop5 = Dropout(drop_rate)(batch5_2)

    # Decoder stage 6: upsample x2, 2x2 conv, concat skip from stage 4.
    up6 = Conv2D(512, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6])
    conv6_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    batch6_1 = BatchNormalization()(conv6_1)
    conv6_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch6_1)
    batch6_2 = BatchNormalization()(conv6_2)

    # Decoder stage 7: skip from stage 3.
    up7 = Conv2D(256, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch6_2))
    merge7 = concatenate([batch3_2, up7])
    conv7_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    batch7_1 = BatchNormalization()(conv7_1)
    conv7_2 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch7_1)
    batch7_2 = BatchNormalization()(conv7_2)

    # Decoder stage 8: skip from stage 2.
    up8 = Conv2D(128, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch7_2))
    merge8 = concatenate([batch2_2, up8])
    conv8_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    batch8_1 = BatchNormalization()(conv8_1)
    conv8_2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch8_1)
    batch8_2 = BatchNormalization()(conv8_2)

    # Decoder stage 9: skip from stage 1, back at full resolution.
    up9 = Conv2D(64, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(batch8_2))
    merge9 = concatenate([batch1_2, up9])
    conv9_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    batch9_1 = BatchNormalization()(conv9_1)
    conv9_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(batch9_1)
    batch9_2 = BatchNormalization()(conv9_2)

    # 1x1 conv to per-pixel class probabilities.
    outputs = Conv2D(classes, 1, activation='softmax')(batch9_2)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
|
|
| |
class BilinearUpsampling(Layer):
    """Resize feature maps to a fixed spatial size with bilinear interpolation.

    Despite the name, ``size`` is the absolute output ``(height, width)``
    handed to ``tf.image.resize``, not a scale factor.
    """

    def __init__(self, size=(1, 1), **kwargs):
        super().__init__(**kwargs)
        self.size = size

    def call(self, inputs):
        return tf.image.resize(inputs, self.size, method='bilinear')

    def compute_output_shape(self, input_shape):
        # Batch and channel dims pass through; spatial dims become ``size``.
        out_height, out_width = self.size[0], self.size[1]
        return (input_shape[0], out_height, out_width, input_shape[3])

    def get_config(self):
        # Include ``size`` so the layer can be rebuilt from a saved config.
        config = super().get_config()
        config['size'] = self.size
        return config
|
|
| |
def DeepLabV3Plus(input_shape=(256, 256, 1), classes=11, output_stride=16):
    """
    DeepLabV3+ segmentation model with an Xception-style backbone.

    The backbone downsamples by a fixed factor of 16 (one stride-2 conv plus
    three stride-2 max-pools).  ``output_stride`` only selects the ASPP
    atrous rates; it does not change any stride.  Layer construction order
    determines Keras' default layer names, which weight loading elsewhere in
    this module depends on, so do not reorder it.

    Args:
        input_shape: Shape of input images ``(height, width, channels)``.
            Height and width should be divisible by 16 so the integer
            upsampling factors computed below restore the input size.
        classes: Number of classes for segmentation
        output_stride: 16 selects atrous rates (6, 12, 18); 8 selects
            (12, 24, 36).

    Returns:
        model: DeepLabV3+ model compiled with Adam (learning rate 1e-4),
        ``categorical_crossentropy`` loss and ``accuracy``.

    Raises:
        ValueError: If ``output_stride`` is neither 8 nor 16.
    """

    inputs = Input(input_shape)

    # Atrous rates for the three dilated ASPP branches.
    if output_stride == 16:
        atrous_rates = (6, 12, 18)
    elif output_stride == 8:
        atrous_rates = (12, 24, 36)
    else:
        raise ValueError("Output stride must be 8 or 16")

    # Entry stem: stride-2 conv (1/2 resolution) then a 3x3 conv.
    x = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(64, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Entry block 1 (128 ch): separable convs + max-pool with a strided 1x1
    # residual projection; output at 1/4 resolution.
    residual = Conv2D(128, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(128, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # Entry block 2 (256 ch): output at 1/8 resolution.
    residual = Conv2D(256, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # 1/8-resolution features kept for the decoder skip connection.
    low_level_features = x

    # Entry block 3 (728 ch): output at 1/16 resolution.
    residual = Conv2D(728, 1, strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D(3, strides=(2, 2), padding='same')(x)
    x = Add()([x, residual])

    # Middle flow: 16 identity-residual units of three dilated (rate 2)
    # separable convs at constant 728 channels and resolution.
    for i in range(16):
        residual = x

        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, 3, padding='same', dilation_rate=2, use_bias=False)(x)
        x = BatchNormalization()(x)

        x = Add()([x, residual])

    # Exit flow: widen to 1024 channels (must match the Reshape below).
    x = Activation('relu')(x)
    x = SeparableConv2D(728, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(1024, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)

    # ASPP branch 1: 1x1 conv.
    aspp_out1 = Conv2D(256, 1, padding='same', use_bias=False)(x)
    aspp_out1 = BatchNormalization()(aspp_out1)
    aspp_out1 = Activation('relu')(aspp_out1)

    # ASPP branches 2-4: 3x3 convs at the selected atrous rates.
    aspp_out2 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[0], use_bias=False)(x)
    aspp_out2 = BatchNormalization()(aspp_out2)
    aspp_out2 = Activation('relu')(aspp_out2)

    aspp_out3 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[1], use_bias=False)(x)
    aspp_out3 = BatchNormalization()(aspp_out3)
    aspp_out3 = Activation('relu')(aspp_out3)

    aspp_out4 = Conv2D(256, 3, padding='same', dilation_rate=atrous_rates[2], use_bias=False)(x)
    aspp_out4 = BatchNormalization()(aspp_out4)
    aspp_out4 = Activation('relu')(aspp_out4)

    # ASPP branch 5: image-level pooling, 1x1 conv, then broadcast back to
    # the feature-map size via upsampling.
    aspp_out5 = GlobalAveragePooling2D()(x)
    aspp_out5 = Reshape((1, 1, 1024))(aspp_out5)
    aspp_out5 = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out5)
    aspp_out5 = BatchNormalization()(aspp_out5)
    aspp_out5 = Activation('relu')(aspp_out5)

    # Static spatial size of x; requires a fully-defined input_shape.
    _, height, width, _ = tf.keras.backend.int_shape(x)
    aspp_out5 = UpSampling2D(size=(height, width), interpolation='bilinear')(aspp_out5)

    aspp_out = Concatenate()([aspp_out1, aspp_out2, aspp_out3, aspp_out4, aspp_out5])

    # Fuse the five ASPP branches.
    aspp_out = Conv2D(256, 1, padding='same', use_bias=False)(aspp_out)
    aspp_out = BatchNormalization()(aspp_out)
    aspp_out = Activation('relu')(aspp_out)

    # Decoder: reduce low-level features to 48 channels.
    low_level_features = Conv2D(48, 1, padding='same', use_bias=False)(low_level_features)
    low_level_features = BatchNormalization()(low_level_features)
    low_level_features = Activation('relu')(low_level_features)

    low_level_shape = tf.keras.backend.int_shape(low_level_features)

    # Upsample ASPP output (1/16) to the low-level resolution (1/8).
    x = UpSampling2D(size=(low_level_shape[1] // tf.keras.backend.int_shape(aspp_out)[1],
                           low_level_shape[2] // tf.keras.backend.int_shape(aspp_out)[2]),
                     interpolation='bilinear')(aspp_out)

    x = Concatenate()([x, low_level_features])

    # Refine the fused features.
    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Integer factor back to input resolution (8 for a 1/8 feature map).
    x_shape = tf.keras.backend.int_shape(x)
    upsampling_size = (input_shape[0] // x_shape[1], input_shape[1] // x_shape[2])

    x = UpSampling2D(size=upsampling_size, interpolation='bilinear')(x)

    # Per-pixel class probabilities.
    outputs = Conv2D(classes, 1, padding='same', activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
|
|
| |
def SegNet(input_shape=(256, 256, 1), classes=11):
    """
    SegNet-style encoder-decoder model for semantic segmentation.

    The encoder has five blocks of 3x3 conv + BN + ReLU (2, 2, 3, 3, 3
    convs; 64/128/256/512/512 filters), each followed by a 2x2 max-pool.
    The decoder mirrors it with ``UpSampling2D`` (not max-pooling indices).
    Unlike ``get_unet`` and ``DeepLabV3Plus``, the returned model is NOT
    compiled.  Layer construction order determines Keras' default layer
    names, which weight loading elsewhere in this module depends on.

    Args:
        input_shape: Shape of input images ``(height, width, channels)``.
            Height and width should be divisible by 32 (five poolings) for
            the output to match the input size.
        classes: Number of classes for segmentation

    Returns:
        model: Uncompiled SegNet-style ``Model`` with a softmax output of
        ``classes`` channels.
    """

    inputs = Input(input_shape)

    # ---- Encoder ----
    # Block 1: 2 x conv 64.
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    # Block 2: 2 x conv 128.
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    # Block 3: 3 x conv 256.
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    # Block 4: 3 x conv 512.
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    # Block 5: 3 x conv 512; output at 1/32 resolution.
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    # ---- Decoder ----
    # Stage 1: upsample x2, 3 x conv 512.
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Stage 2: upsample x2, conv 512, 512, 256.
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Stage 3: upsample x2, conv 256, 256, 128.
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Stage 4: upsample x2, conv 128, 64.
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Stage 5: upsample x2 back to input resolution, 2 x conv 64.
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # 1x1 conv to per-pixel class probabilities.
    outputs = Conv2D(classes, (1, 1), padding='same', activation='softmax')(x)

    # Note: no model.compile() here, unlike the other builders.
    model = Model(inputs=inputs, outputs=outputs)

    return model
|
|
| |
|
|
class SARSegmentation:
    """Land-cover segmentation of single-band SAR tiles.

    Holds the segmentation model (U-Net, DeepLabV3+ or SegNet), the class
    codes it predicts, and helpers for loading, normalising and colouring
    data.

    Attributes (set in ``__init__``):
        img_rows, img_cols: Spatial size inputs are resized to.
        drop_rate: Dropout rate used when a fresh U-Net is built.
        num_channels: Input channel count (1).
        model: The Keras model, or ``None`` until one is loaded.
        model_type: ``'unet'``, ``'deeplabv3plus'`` or ``'segnet'``.
        class_definitions: Class name -> label code (10, 20, ..., 100).
        num_classes: ``len(class_definitions)``.
        class_colors: Class index -> RGB colour from ``get_esa_colors()``.
    """

    def __init__(self, img_rows=256, img_cols=256, drop_rate=0.5, model_type='unet'):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.drop_rate = drop_rate
        self.num_channels = 1
        self.model = None
        self.model_type = model_type.lower()

        # Class name -> label code; index order is the sorted code order.
        self.class_definitions = {
            'trees': 10,
            'shrubland': 20,
            'grassland': 30,
            'cropland': 40,
            'built_up': 50,
            'bare': 60,
            'snow': 70,
            'water': 80,
            'wetland': 90,
            'mangroves': 95,
            'moss': 100
        }
        self.num_classes = len(self.class_definitions)

        self.class_colors = get_esa_colors()

    def load_sar_data(self, file_path_or_bytes, is_bytes=False):
        """Load SAR data from file path or bytes.

        Band 1 is read with rasterio; if that fails the input is read as a
        greyscale image with PIL.  The result is resized to
        ``(img_rows, img_cols)`` when needed.

        Args:
            file_path_or_bytes: A file path, or raw file bytes when ``is_bytes``.
            is_bytes: Treat the first argument as bytes (spilled to a temp
                ``.tif`` file, which is always removed afterwards).

        Returns:
            Array of shape ``(img_rows, img_cols, 1)`` in the source dtype.

        Raises:
            ValueError: If the data cannot be read by either reader.
        """
        try:
            if is_bytes:
                # rasterio opens paths, so write the bytes to a temp file.
                with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp:
                    tmp.write(file_path_or_bytes)
                    tmp_path = tmp.name
                try:
                    try:
                        with rasterio.open(tmp_path) as src:
                            sar_data = src.read(1)
                    except Exception:
                        # Not something rasterio understands: read as greyscale.
                        with Image.open(tmp_path) as img:
                            sar_data = np.array(img.convert('L'))
                finally:
                    # Remove the temp file even if both readers fail.
                    os.unlink(tmp_path)
            else:
                try:
                    with rasterio.open(file_path_or_bytes) as src:
                        sar_data = src.read(1)
                except RasterioIOError:
                    with Image.open(file_path_or_bytes) as img:
                        sar_data = np.array(img.convert('L'))

            sar_data = np.expand_dims(sar_data, axis=-1)

            if sar_data.shape[:2] != (self.img_rows, self.img_cols):
                # cv2.resize takes (width, height) and drops the single
                # channel axis, so restore it afterwards.
                sar_data = cv2.resize(sar_data, (self.img_cols, self.img_rows))
                sar_data = np.expand_dims(sar_data, axis=-1)

            return sar_data
        except Exception as e:
            raise ValueError(f"Failed to load SAR data: {str(e)}") from e

    def preprocess_sar(self, sar_data):
        """Preprocess SAR data into the [-1, 1] range.

        Data entirely within [0, 255] is treated as 8-bit imagery and mapped
        with ``x / 127.5 - 1``.  NOTE(review): float data that happens to lie
        in [0, 255] also takes this branch.  Anything else is clipped to
        [-50, 20] (presumably dB — confirm) and min-max scaled; if the
        clipped data is constant every output is -1 rather than NaN.
        """
        if np.max(sar_data) <= 255 and np.min(sar_data) >= 0:
            return (sar_data / 127.5) - 1

        sar_clipped = np.clip(sar_data, -50, 20)
        lo = np.min(sar_clipped)
        span = np.max(sar_clipped) - lo
        if span == 0:
            # Constant input: numerators are all 0; avoid 0/0 -> NaN.
            span = 1
        return (sar_clipped - lo) / span * 2 - 1

    def one_hot_encode(self, labels):
        """Convert label codes to a ``(H, W, num_classes)`` one-hot float array.

        Channel ``i`` corresponds to the i-th smallest code in
        ``class_definitions``; pixels with any other value are all-zero.
        """
        encoded = np.zeros((labels.shape[0], labels.shape[1], self.num_classes))

        for i, value in enumerate(sorted(self.class_definitions.values())):
            encoded[:, :, i] = (labels == value)

        return encoded

    @staticmethod
    def _is_dilated_conv(layer):
        """Return True if *layer* is named like a conv and has a dilation rate > 1."""
        if 'conv' not in layer.name.lower() or not hasattr(layer, 'dilation_rate'):
            return False
        rate = layer.dilation_rate
        if isinstance(rate, (list, tuple)):
            return any(r > 1 for r in rate)
        return rate > 1

    def load_trained_model(self, model_path):
        """Load a trained model from file and record its architecture.

        ``load_model_with_weights`` is tried first.  If it returns a model,
        ``model_type`` is inferred from it: any dilated conv -> DeepLabV3+,
        five or more ``MaxPooling2D`` layers -> SegNet, otherwise U-Net.  If
        it returns ``None``, a fresh model of ``self.model_type`` is built and
        weights matching by layer name are loaded into it.

        Args:
            model_path: Path to the model/weights file.  A bare file name is
                looked up under ``models/``.

        Raises:
            ValueError: On any failure, including an unsupported model type
                or a file from which no weights could be loaded.
        """
        try:
            if not os.path.dirname(model_path) and not model_path.startswith('models/'):
                model_path = os.path.join('models', os.path.basename(model_path))

            self.model = load_model_with_weights(model_path)

            if self.model is not None:
                if any(self._is_dilated_conv(layer) for layer in self.model.layers):
                    self.model_type = 'deeplabv3plus'
                    print("Detected DeepLabV3+ model")
                elif sum(isinstance(layer, MaxPooling2D) for layer in self.model.layers) >= 5:
                    self.model_type = 'segnet'
                    print("Detected SegNet model")
                else:
                    self.model_type = 'unet'
                    print("Detected U-Net model")
                return

            input_shape = (self.img_rows, self.img_cols, self.num_channels)
            if self.model_type == 'unet':
                self.model = get_unet(
                    input_shape=input_shape,
                    drop_rate=self.drop_rate,
                    classes=self.num_classes
                )
            elif self.model_type == 'deeplabv3plus':
                self.model = DeepLabV3Plus(input_shape=input_shape, classes=self.num_classes)
            elif self.model_type == 'segnet':
                self.model = SegNet(input_shape=input_shape, classes=self.num_classes)
            else:
                raise ValueError(f"Model type {self.model_type} not supported")

            # A freshly built model already has non-zero (random) weights, so
            # detect a no-op load by comparing against a pre-load snapshot.
            weights_before = [w.copy() for w in self.model.get_weights()]
            self.model.load_weights(model_path, by_name=True, skip_mismatch=True)
            weights_after = self.model.get_weights()
            if all(np.array_equal(a, b) for a, b in zip(weights_before, weights_after)):
                raise ValueError("No weights were loaded. The model architecture is incompatible.")
        except Exception as e:
            raise ValueError(f"Failed to load model: {str(e)}") from e

    def predict(self, sar_data):
        """Predict segmentation for new SAR data.

        Args:
            sar_data: ``(H, W, C)`` or already-batched array; it is passed
                through ``preprocess_sar`` and given a batch axis if 3-D.

        Returns:
            The raw output of ``self.model.predict``.

        Raises:
            ValueError: If no model has been loaded.
        """
        if self.model is None:
            raise ValueError("Model not trained. Call train() first or load a trained model.")

        sar_processed = self.preprocess_sar(sar_data)

        if sar_processed.ndim == 3:
            sar_processed = np.expand_dims(sar_processed, axis=0)

        prediction = self.model.predict(sar_processed)
        return prediction

    def get_colored_prediction(self, prediction):
        """Convert a batched prediction to ``(colored_rgb_uint8, class_indices)``.

        Only ``prediction[0]`` is used; classes are the argmax over the last
        axis and are painted with ``self.class_colors``.
        """
        pred_class = np.argmax(prediction[0], axis=-1)
        colored_pred = np.zeros((pred_class.shape[0], pred_class.shape[1], 3), dtype=np.uint8)

        for class_idx, color in self.class_colors.items():
            colored_pred[pred_class == class_idx] = color

        return colored_pred, pred_class
|
|
| |
|
|
| |
| |
# Seed per-session defaults. A key is only written when it is still missing,
# so values chosen during the session are kept across script reruns. Each
# default is produced by a factory, which keeps construction lazy: the
# SARSegmentation helper is only built when the key is absent.
for _state_key, _make_default in (
    ('app_mode', lambda: "SAR Colorization"),
    ('model_loaded', lambda: False),
    ('segmentation', lambda: SARSegmentation(img_rows=256, img_cols=256)),
    ('processed_images', list),
    ('theme', lambda: "dark"),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
del _state_key, _make_default
| |
def set_app_style(app_mode):
    """Inject the page-level CSS that belongs to the selected application.

    ``"SAR Colorization"`` gets a dark palette and ``"SAR to Optical
    Translation"`` a light one with darker text and lighter cards/metrics.
    Both share the same full-page background photo under a translucent,
    blurred overlay. Any other *app_mode* injects nothing.

    Args:
        app_mode: Label of the application chosen in the sidebar.
    """
    dark_css = """
    <style>
    .stApp {
        background-color: #0a0a1f;
        color: white;
    }

    .main {
        background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
        background-size: cover;
        background-position: center;
        background-repeat: no-repeat;
        background-attachment: fixed;
        position: relative;
    }

    .main::before {
        content: "";
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        background-color: rgba(10, 10, 31, 0.7);
        backdrop-filter: blur(5px);
        z-index: -1;
    }

    /* Rest of your dark theme CSS */
    /* ... */
    </style>
    """

    light_css = """
    <style>
    .stApp {
        background-color: #f8f9fa;
        color: #333;
    }

    .main {
        background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80");
        background-size: cover;
        background-position: center;
        background-repeat: no-repeat;
        background-attachment: fixed;
        position: relative;
    }

    .main::before {
        content: "";
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        background-color: rgba(248, 249, 250, 0.7);
        backdrop-filter: blur(5px);
        z-index: -1;
    }

    /* Adjust text colors for light theme */
    h1, h2, h3, h4, h5, h6 {
        color: #333 !important;
    }

    p, span, div, label {
        color: #333 !important;
    }

    /* Adjust card styling for light theme */
    .card {
        background-color: rgba(255, 255, 255, 0.7) !important;
        border: 1px solid rgba(147, 51, 234, 0.3) !important;
    }

    /* Adjust metric card styling for light theme */
    .metric-card {
        background-color: rgba(255, 255, 255, 0.7) !important;
    }

    .metric-value {
        color: #7c3aed !important;
    }

    .metric-label {
        color: #333 !important;
    }

    /* Rest of your light theme adjustments */
    /* ... */
    </style>
    """

    # Ordered (mode, stylesheet) pairs; the first mode equal to app_mode wins.
    stylesheet_for_mode = (
        ("SAR Colorization", dark_css),
        ("SAR to Optical Translation", light_css),
    )
    for mode, css in stylesheet_for_mode:
        if app_mode == mode:
            st.markdown(css, unsafe_allow_html=True)
            break
| |
| |
def create_stars_html(num_stars=100):
    """Return HTML for a field of randomly placed stars.

    Each star is a ``div.star`` with a random size (1-3 px) and a random
    position (0-100 % from the top and the left). It also carries two CSS
    custom properties, ``--duration`` (3-8 s) and ``--opacity`` (0.2-0.8),
    presumably consumed by a twinkle animation defined in CSS elsewhere.
    All stars are wrapped in one ``div.stars`` container. Values come from the
    global ``random`` module, so seeding it makes the markup reproducible.

    Args:
        num_stars: Number of star elements to emit.
    """
    fragments = ['<div class="stars">']
    for _ in range(num_stars):
        # Keep this draw order: it fixes which random value lands where.
        size = random.uniform(1, 3)
        top = random.uniform(0, 100)
        left = random.uniform(0, 100)
        duration = random.uniform(3, 8)
        opacity = random.uniform(0.2, 0.8)
        fragments.append(f"""
    <div class="star" style="
        width: {size}px;
        height: {size}px;
        top: {top}%;
        left: {left}%;
        --duration: {duration}s;
        --opacity: {opacity};
    "></div>
    """)
    fragments.append("</div>")
    # One join instead of repeated string concatenation.
    return "".join(fragments)
|
|
| |
def add_logo(logo_path='assets/logo2.png'):
    """Overlay a logo in the top-left corner of the page.

    The file is inlined as a base64 ``data:`` URI, always declared as PNG.
    It is placed inside an absolutely positioned ``<div>`` and rendered
    150 px wide. A missing file shows a Streamlit warning instead of raising.

    Args:
        logo_path: Path of the image file to embed.
    """
    try:
        with open(logo_path, "rb") as logo_file:
            encoded_logo = base64.b64encode(logo_file.read()).decode()
        logo_html = (
            f"""<div style="position: absolute; top: 0.5rem; left: 1rem; z-index: 999;">
<img src="data:image/png;base64,{encoded_logo}" width="150px"></div>"""
        )
        st.markdown(logo_html, unsafe_allow_html=True)
    except FileNotFoundError:
        st.warning(f"Logo file not found: {logo_path}")
|
|
| |
|
|
| |
# Decorative background: inject a freshly randomised star field on every run.
st.markdown(create_stars_html(), unsafe_allow_html=True)




# Sidebar: application switcher, appearance settings, per-application
# info/configuration, and the land-cover colour legend.
with st.sidebar:
    st.image('assets/logo2.png', width=150)


    st.title("Applications")
    app_mode = st.radio(
        "Select Application",
        ["SAR Colorization", "SAR to Optical Translation"]
    )
    # Persist the choice; the main page branches on st.session_state.app_mode.
    st.session_state.app_mode = app_mode

    st.markdown("---")
    st.title("Appearance")
    theme = st.radio(
        "Select Theme",
        ["Dark", "Light"]
    )
    # NOTE(review): the injected CSS is selected from the application mode,
    # not from `theme`; the theme choice is only stored in session state here.
    set_app_style(st.session_state.app_mode)

    # Store a changed theme and restart the script run so the new value is
    # visible from the top of the script.
    if theme.lower() != st.session_state.theme:
        st.session_state.theme = theme.lower()
        st.rerun()

    st.markdown("---")


    # Mode-specific sidebar content.
    if st.session_state.app_mode == "SAR Colorization":
        st.title("About")
        st.markdown("""
        ### SAR Image Colorization

        This application uses deep learning models to segment and colorize Synthetic Aperture Radar (SAR) images into land cover classes.

        #### Features:
        - Load pre-trained U-Net,DeepLabV3+ or SegNet models
        - Process single SAR images
        - Batch process multiple images
        - Visualize Pixel Level Classification with ESA WorldCover color scheme

        #### Developed by:
        Varun & Mokshyagna
        (NRSC, ISRO)

        #### Technologies:
        - TensorFlow/Keras
        - Streamlit
        - Rasterio
        - OpenCV

        #### Version:
        1.0.0
        """)

    elif st.session_state.app_mode == "SAR to Optical Translation":
        st.header("Model Configuration")

        # Fixed model locations, relative to the working directory.
        unet_weights_path = "models/unet_model.h5"
        generator_path = "models/final_generator.keras"

        st.info(f"U-Net Weights Path: {unet_weights_path}")

        # The generator is optional; when it is disabled its path is cleared.
        use_generator = st.checkbox("Use Generator Model for Colorization", value=True)
        if use_generator:
            st.info(f"Generator Model Path: {generator_path}")
        else:
            generator_path = None

        # Load the models on demand and keep them in session state, where the
        # main page checks for 'unet_model'.
        if st.button("Load Models"):
            with st.spinner("Loading models..."):
                gpu_status = setup_gpu()
                st.info(gpu_status)

                try:
                    unet_model, generator_model = load_models(unet_weights_path, generator_path if use_generator else None)
                    st.session_state['unet_model'] = unet_model
                    st.session_state['generator_model'] = generator_model
                    st.success("Models loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading models: {e}")


    # Colour legend for the land-cover classes (RGB triples).
    st.header("ESA WorldCover Classes")
    class_info = {
        'Trees': [0, 100, 0],
        'Shrubland': [255, 165, 0],
        'Grassland': [144, 238, 144],
        'Cropland': [255, 255, 0],
        'Built-up': [255, 0, 0],
        'Bare': [139, 69, 19],
        'Snow': [255, 255, 255],
        'Water': [0, 0, 255],
        'Wetland': [0, 139, 139],
        'Mangroves': [0, 255, 0],
        'Moss': [220, 220, 220]
    }

    # One row per class: a 20x20 colour swatch followed by the class name.
    for class_name, color in class_info.items():
        st.markdown(
            f'<div style="display: flex; align-items: center;">'
            f'<div style="width: 20px; height: 20px; background-color: rgb({color[0]}, {color[1]}, {color[2]}); margin-right: 10px;"></div>'
            f'<span>{class_name}</span>'
            f'</div>',
            unsafe_allow_html=True
        )

    st.markdown("---")
    st.markdown("© 2025 | All Rights Reserved")
|
|
| |
| if st.session_state.app_mode == "SAR Colorization": |
| |
# Title banner for the SAR Colorization application.
st.markdown("""
<div style="text-align: center; margin-bottom: 2rem; position: relative; z-index: 100;">
    <h1 style="color: #a78bfa; font-size: 3rem; font-weight: bold; text-shadow: 0 0 10px rgba(167, 139, 250, 0.5);">
        SAR Image Colorization
    </h1>
    <p style="color: #bfdbfe; font-size: 1.2rem;">
        Pixel Level Classification of Synthetic Aperture Radar images into land cover classes with deep learning
    </p>
</div>
""", unsafe_allow_html=True)


# Opening tag of a styled "card" wrapper.
# NOTE(review): each st.markdown call is rendered as its own element, so this
# tag probably does not actually wrap the widgets that follow — confirm.
st.markdown("<div class='card'>", unsafe_allow_html=True)



# One tab per workflow; the handles are used by the `with tabN:` sections below.
tab1, tab2, tab3, tab4 = st.tabs(["📥 Load Model", "🖼️ Process Single Image", "📁 Process Multiple Images", "🔍 Sample Images"])
|
|
|
|
| |
# Tab 1: choose a segmentation architecture, load its weights, and show a
# summary of the loaded model.
with tab1:
    st.markdown("<h3 style='color: #a78bfa;'>Load Segmentation Model</h3>", unsafe_allow_html=True)

    # Display label of the chosen architecture (also shown in the spinner).
    model_type = st.selectbox(
        "Select model architecture",
        ["U-Net", "DeepLabV3+", "SegNet"],
        index=0,
        help="Select the architecture of the model to load"
    )

    # Translate the display label into the internal key that
    # load_trained_model compares against ('unet' / 'deeplabv3plus' /
    # 'segnet') and that model_arch_map below understands. Lower-casing and
    # stripping '-' turned "DeepLabV3+" into "deeplabv3+", which neither
    # recognises.
    model_type_keys = {
        "U-Net": "unet",
        "DeepLabV3+": "deeplabv3plus",
        "SegNet": "segnet",
    }
    st.session_state.segmentation.model_type = model_type_keys[model_type]

    # Weight file per architecture key ("deeplabv3+" kept as an alias).
    model_paths = {
        "unet": "models/unet_model.h5",
        "deeplabv3+": "models/deeplabv3plus_model.h5",
        "deeplabv3plus": "models/deeplabv3plus_model.h5",
        "segnet": "models/segnet_model.h5"
    }
    selected_model_path = model_paths[st.session_state.segmentation.model_type]

    st.info(f"Model will be loaded from: {selected_model_path}")

    if st.button("Load Model", key="load_model_btn"):
        with st.spinner(f"Loading {model_type} model..."):
            try:
                st.session_state.segmentation.load_trained_model(selected_model_path)
                st.session_state.model_loaded = True
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")

    # Summary card: architecture, class count, input size and colour legend.
    if st.session_state.model_loaded:
        st.markdown("<div class='card'>", unsafe_allow_html=True)
        st.markdown("<h4 style='color: #a78bfa;'>Model Information</h4>", unsafe_allow_html=True)

        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
            model_arch_map = {
                'unet': "U-Net",
                'deeplabv3plus': "DeepLabV3+",
                'segnet': "SegNet"
            }
            model_arch = model_arch_map.get(st.session_state.segmentation.model_type, "Unknown")
            st.markdown(f"<p class='metric-value'>{model_arch}</p>", unsafe_allow_html=True)
            st.markdown("<p class='metric-label'>Architecture</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

        with col2:
            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
            st.markdown("<p class='metric-value'>11</p>", unsafe_allow_html=True)
            st.markdown("<p class='metric-label'>Land Cover Classes</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

        with col3:
            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
            st.markdown("<p class='metric-value'>256 x 256</p>", unsafe_allow_html=True)
            st.markdown("<p class='metric-label'>Input Size</p>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)

        st.markdown("<h4 style='color: #a78bfa; margin-top: 20px;'>Land Cover Classes</h4>", unsafe_allow_html=True)
        legend_img = create_legend()
        st.image(legend_img, use_container_width=True)

        st.markdown("</div>", unsafe_allow_html=True)
    else:
        st.info("Please load a model to continue.")
|
|
| |
# Tab 2: segment one uploaded SAR image. If a ground-truth raster is also
# uploaded, show the comparison view and pixel accuracy.
with tab2:
    st.markdown("<h3 style='color: #a78bfa;'>Process Single SAR Image</h3>", unsafe_allow_html=True)

    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='upload-box'>", unsafe_allow_html=True)
        col1, col2 = st.columns(2)

        with col1:
            uploaded_file = st.file_uploader(
                "Upload a SAR image (.tif or common image formats)",
                type=["tif", "tiff", "png", "jpg", "jpeg"],
                key="single_sar_uploader"
            )

        with col2:
            # Optional label raster used to score the prediction.
            ground_truth_file = st.file_uploader(
                "Upload ground truth (optional)",
                type=["tif", "tiff", "png", "jpg", "jpeg"],
                key="single_gt_uploader"
            )

        st.markdown("</div>", unsafe_allow_html=True)

        if uploaded_file is not None:
            if st.button("Process Image", key="process_single_btn"):
                with st.spinner("Processing image..."):
                    try:
                        sar_data = st.session_state.segmentation.load_sar_data(uploaded_file.getvalue(), is_bytes=True)

                        # 8-bit display copy stretched to 0..255. A constant
                        # image would divide 0/0 (NaN), so render it black.
                        min_val, max_val = np.min(sar_data), np.max(sar_data)
                        if max_val > min_val:
                            sar_normalized = ((sar_data - min_val) / (max_val - min_val) * 255).astype(np.uint8)
                        else:
                            sar_normalized = np.zeros(np.shape(sar_data), dtype=np.uint8)

                        prediction = st.session_state.segmentation.predict(sar_data)

                        # True once the ground-truth comparison has been rendered.
                        shown_with_gt = False
                        if ground_truth_file is not None:
                            try:
                                gt_data = st.session_state.segmentation.load_sar_data(ground_truth_file.getvalue(), is_bytes=True)

                                # Rescale the 0..255 display copy to [-1, 1]
                                # for visualize_with_ground_truth.
                                if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0:
                                    sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1
                                else:
                                    sar_for_viz = sar_normalized

                                result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                    sar_for_viz,
                                    gt_data,
                                    prediction
                                )

                                st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results with Ground Truth</h4>", unsafe_allow_html=True)
                                st.image(result_buf, use_container_width=True)

                                pred_class = np.argmax(prediction[0], axis=-1)
                                gt_class = gt_data[:, :, 0].astype(np.int32)

                                # Labels above 10 are raw class codes; map the
                                # i-th smallest code to class index i.
                                if np.max(gt_class) > 10:
                                    gt_mapped = np.zeros_like(gt_class)
                                    class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                    for class_pos, val in enumerate(class_values):
                                        gt_mapped[gt_class == val] = class_pos
                                    gt_class = gt_mapped

                                accuracy = np.mean(pred_class == gt_class) * 100

                                st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                                st.markdown(f"<p class='metric-value'>{accuracy:.2f}%</p>", unsafe_allow_html=True)
                                st.markdown("<p class='metric-label'>Pixel Accuracy</p>", unsafe_allow_html=True)
                                st.markdown("</div>", unsafe_allow_html=True)

                                st.download_button(
                                    label="Download Result",
                                    data=result_buf,
                                    file_name="segmentation_result_with_gt.png",
                                    mime="image/png",
                                    key="download_single_result_with_gt"
                                )
                                shown_with_gt = True
                            except Exception as e:
                                st.error(f"Error processing ground truth: {str(e)}")

                        # Plain segmentation view: used when no ground truth was
                        # given or the comparison failed, so results are never
                        # rendered twice.
                        if not shown_with_gt:
                            result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)
                            st.image(result_img, use_container_width=True)

                            st.download_button(
                                label="Download Result",
                                data=result_img,
                                file_name="segmentation_result.png",
                                mime="image/png",
                                key="download_single_result"
                            )
                    except Exception as e:
                        st.error(f"Error processing image: {str(e)}")
|
|
| |
# Tab 3: batch-segment uploaded SAR images (loose files and/or ZIP archives),
# optionally scoring each against a ground-truth file of the same name, and
# offer all results as one ZIP download.
with tab3:
    st.markdown("<h3 style='color: #a78bfa;'>Process Multiple SAR Images</h3>", unsafe_allow_html=True)

    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='upload-box'>", unsafe_allow_html=True)

        # Ground-truth uploads are only offered when this box is ticked.
        use_gt = st.checkbox("Include ground truth data", value=False)

        col1, col2 = st.columns(2)

        with col1:
            uploaded_files = st.file_uploader(
                "Upload SAR images or a ZIP file containing images",
                type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
                accept_multiple_files=True,
                key="batch_sar_uploader"
            )

        gt_files = None
        if use_gt:
            with col2:
                gt_files = st.file_uploader(
                    "Upload ground truth images or a ZIP file (must match SAR filenames)",
                    type=["tif", "tiff", "png", "jpg", "jpeg", "zip"],
                    accept_multiple_files=True,
                    key="batch_gt_uploader"
                )
                st.info("Ground truth filenames should match SAR image filenames")

        st.markdown("</div>", unsafe_allow_html=True)

        col1, col2 = st.columns([3, 1])

        with col1:
            max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=10)

        with col2:
            st.markdown("<br>", unsafe_allow_html=True)
            process_btn = st.button("Process Images", key="process_multi_btn")

        if process_btn and uploaded_files:
            st.session_state.processed_images = []
            # Per-image pixel accuracies; stays empty unless ground truth matched.
            overall_accuracy = []

            def _stage_uploads(files, target_dir):
                """Write uploads into target_dir (extracting ZIPs) and return sorted paths of all images found there."""
                os.makedirs(target_dir, exist_ok=True)
                for upload in files:
                    if upload.name.lower().endswith('.zip'):
                        with zipfile.ZipFile(io.BytesIO(upload.getvalue())) as archive:
                            archive.extractall(target_dir)
                    else:
                        # basename() keeps a client-supplied name inside target_dir.
                        with open(os.path.join(target_dir, os.path.basename(upload.name)), 'wb') as out_file:
                            out_file.write(upload.getvalue())
                # Walk once, after everything is staged. Walking after each
                # archive re-added files that were already collected.
                return sorted(
                    os.path.join(root, name)
                    for root, _, names in os.walk(target_dir)
                    for name in names
                    if name.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))
                )

            with st.spinner("Processing images..."):
                with tempfile.TemporaryDirectory() as temp_dir:
                    sar_image_files = _stage_uploads(uploaded_files, os.path.join(temp_dir, 'sar'))

                    # Ground truth is matched to SAR images by identical file name.
                    gt_image_files = {}
                    if use_gt and gt_files:
                        gt_image_files = {
                            os.path.basename(gt_path): gt_path
                            for gt_path in _stage_uploads(gt_files, os.path.join(temp_dir, 'gt'))
                        }

                    if len(sar_image_files) > max_images:
                        st.info(f"Found {len(sar_image_files)} images. Randomly selecting {max_images} images to display.")
                        sar_image_files = random.sample(sar_image_files, max_images)

                    progress_bar = st.progress(0)

                    for image_idx, image_path in enumerate(sar_image_files):
                        image_basename = os.path.basename(image_path)
                        try:
                            progress_bar.progress((image_idx + 1) / len(sar_image_files))

                            sar_data = st.session_state.segmentation.load_sar_data(image_path)

                            # 8-bit display copy; a constant image would divide
                            # 0/0 (NaN), so render it black instead.
                            min_val, max_val = np.min(sar_data), np.max(sar_data)
                            if max_val > min_val:
                                sar_normalized = ((sar_data - min_val) / (max_val - min_val) * 255).astype(np.uint8)
                            else:
                                sar_normalized = np.zeros(np.shape(sar_data), dtype=np.uint8)

                            prediction = st.session_state.segmentation.predict(sar_data)

                            if use_gt and image_basename in gt_image_files:
                                gt_data = st.session_state.segmentation.load_sar_data(gt_image_files[image_basename])

                                # Rescale the 0..255 display copy to [-1, 1].
                                if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0:
                                    sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1
                                else:
                                    sar_for_viz = sar_normalized

                                result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                    sar_for_viz,
                                    gt_data,
                                    prediction
                                )

                                pred_class = np.argmax(prediction[0], axis=-1)
                                gt_class = gt_data[:, :, 0].astype(np.int32)

                                # Labels above 10 are raw class codes; map the
                                # i-th smallest code to class index i.
                                if np.max(gt_class) > 10:
                                    gt_mapped = np.zeros_like(gt_class)
                                    class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                    for class_pos, val in enumerate(class_values):
                                        gt_mapped[gt_class == val] = class_pos
                                    gt_class = gt_mapped

                                accuracy = np.mean(pred_class == gt_class) * 100
                                overall_accuracy.append(accuracy)

                                st.session_state.processed_images.append({
                                    'filename': image_basename,
                                    'result': result_buf,
                                    'accuracy': accuracy
                                })
                            else:
                                result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))

                                st.session_state.processed_images.append({
                                    'filename': image_basename,
                                    'result': result_img
                                })
                        except Exception as e:
                            st.error(f"Error processing {image_basename}: {str(e)}")

                    progress_bar.empty()

            if st.session_state.processed_images:
                st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)

                if overall_accuracy:
                    avg_accuracy = np.mean(overall_accuracy)
                    st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                    st.markdown(f"<p class='metric-value'>{avg_accuracy:.2f}%</p>", unsafe_allow_html=True)
                    st.markdown("<p class='metric-label'>Average Pixel Accuracy</p>", unsafe_allow_html=True)
                    st.markdown("</div>", unsafe_allow_html=True)

                # Bundle every rendered result (an in-memory buffer) into one ZIP.
                zip_buffer = io.BytesIO()
                with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
                    for result_no, img_data in enumerate(st.session_state.processed_images, start=1):
                        zip_file.writestr(f"result_{result_no}_{img_data['filename']}.png", img_data['result'].getvalue())

                st.download_button(
                    label="Download All Results",
                    data=zip_buffer.getvalue(),
                    file_name="segmentation_results.zip",
                    mime="application/zip",
                    key="download_all_results"
                )

                for img_data in st.session_state.processed_images:
                    st.markdown(f"<h5 style='color: #bfdbfe;'>Image: {img_data['filename']}</h5>", unsafe_allow_html=True)

                    if 'accuracy' in img_data:
                        st.markdown(f"<p style='color: #a78bfa;'>Pixel Accuracy: {img_data['accuracy']:.2f}%</p>", unsafe_allow_html=True)

                    st.image(img_data['result'], use_container_width=True)
                    st.markdown("<hr style='border-color: rgba(147, 51, 234, 0.3);'>", unsafe_allow_html=True)
            else:
                st.warning("No images were successfully processed.")
        elif process_btn:
            st.warning("Please upload at least one image file or ZIP archive.")
|
|
| |
# Tab 4: browse bundled samples (SAR / optical / label side by side) and
# segment the selected SAR sample, scoring it when a label file exists.
with tab4:
    st.markdown("<h3 style='color: #a78bfa;'>Sample Images</h3>", unsafe_allow_html=True)

    if not st.session_state.model_loaded:
        st.warning("Please load a model in the 'Load Model' tab first.")
    else:
        st.markdown("<div class='card'>", unsafe_allow_html=True)

        image_exts = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')
        sample_dir = "samples/SAR"
        if os.path.exists(sample_dir):
            # Case-insensitive like the rest of the app (e.g. "scene.TIF"); sorted for a stable list.
            sample_files = sorted(f for f in os.listdir(sample_dir) if f.lower().endswith(image_exts))
        else:
            # First run: create the expected sample folder layout.
            os.makedirs(sample_dir, exist_ok=True)
            os.makedirs("samples/OPTICAL", exist_ok=True)
            os.makedirs("samples/LABELS", exist_ok=True)
            sample_files = []

        if sample_files:
            selected_sample = st.selectbox(
                "Select a sample image",
                sample_files,
                key="sample_selector"
            )

            col1, col2, col3 = st.columns(3)

            with col1:
                st.subheader("SAR Image")
                sar_path = os.path.join(sample_dir, selected_sample)
                display_image(sar_path)

            with col2:
                st.subheader("Optical Image (Ground Truth)")
                opt_path = os.path.join("samples/OPTICAL", selected_sample)
                if os.path.exists(opt_path):
                    display_image(opt_path)
                else:
                    st.info("No matching optical image found")

            with col3:
                st.subheader("Label Image")
                samples_dir = "samples"

                # Folder names tried, in order, for the matching label file.
                possible_label_dirs = [
                    os.path.join(samples_dir, name)
                    for name in ("labels", "label", "LABELS", "LABEL", "Labels", "Label",
                                 "gt", "GT", "ground_truth", "groundtruth")
                ]

                # Resolve the label: identical filename, then same stem with a
                # known extension, then a case-insensitive stem match.
                label_path = None
                base_name = os.path.splitext(selected_sample)[0]
                for dir_path in possible_label_dirs:
                    if not os.path.isdir(dir_path):
                        continue

                    exact_path = os.path.join(dir_path, selected_sample)
                    if os.path.exists(exact_path):
                        label_path = exact_path
                        break

                    for ext in ['.tif', '.tiff', '.png', '.jpg', '.jpeg', '.TIF', '.TIFF', '.PNG', '.JPG', '.JPEG']:
                        test_path = os.path.join(dir_path, base_name + ext)
                        if os.path.exists(test_path):
                            label_path = test_path
                            break

                    if not label_path:
                        for file in os.listdir(dir_path):
                            if os.path.splitext(file)[0].lower() == base_name.lower():
                                label_path = os.path.join(dir_path, file)
                                break

                    if label_path:
                        break

                if label_path and os.path.exists(label_path):
                    try:
                        if label_path.lower().endswith(('.tif', '.tiff')):
                            with rasterio.open(label_path) as src:
                                label_data = src.read(1)

                            colors = get_esa_colors()
                            colored_label = np.zeros((label_data.shape[0], label_data.shape[1], 3), dtype=np.uint8)

                            if np.max(label_data) > 10:
                                # Raw class codes: colour index i belongs to the
                                # i-th smallest code. Computed once, not per class.
                                raw_code_for_index = dict(enumerate(
                                    sorted(st.session_state.segmentation.class_definitions.values())
                                ))
                                for class_idx, color in colors.items():
                                    if class_idx in raw_code_for_index:
                                        colored_label[label_data == raw_code_for_index[class_idx]] = color
                            else:
                                for class_idx, color in colors.items():
                                    colored_label[label_data == class_idx] = color

                            st.image(colored_label, use_container_width=True)
                        else:
                            display_image(label_path)
                    except Exception as e:
                        st.error(f"Error displaying label image: {str(e)}")
                        # Fall back to the generic image renderer.
                        display_image(label_path)
                else:
                    st.info("No matching label image found")

            if st.button("Process Selected Sample", key="process_sample_btn"):
                with st.spinner("Processing sample image..."):
                    try:
                        sar_data = st.session_state.segmentation.load_sar_data(sar_path)

                        # 8-bit display copy; a constant image would divide 0/0
                        # (NaN), so render it black instead.
                        min_val, max_val = np.min(sar_data), np.max(sar_data)
                        if max_val > min_val:
                            sar_normalized = ((sar_data - min_val) / (max_val - min_val) * 255).astype(np.uint8)
                        else:
                            sar_normalized = np.zeros(np.shape(sar_data), dtype=np.uint8)

                        prediction = st.session_state.segmentation.predict(sar_data)

                        # label_path is None when no label was found, and
                        # os.path.exists(None) raises TypeError, so test it first.
                        if label_path and os.path.exists(label_path):
                            gt_data = st.session_state.segmentation.load_sar_data(label_path)

                            # Rescale the 0..255 display copy to [-1, 1].
                            if np.max(sar_normalized) > 1 or np.min(sar_normalized) < 0:
                                sar_for_viz = (sar_normalized.astype(np.float32) / 127.5) - 1
                            else:
                                sar_for_viz = sar_normalized

                            result_buf, colored_gt, colored_pred, overlay = visualize_with_ground_truth(
                                sar_for_viz,
                                gt_data,
                                prediction
                            )

                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results with Ground Truth</h4>", unsafe_allow_html=True)
                            st.image(result_buf, use_container_width=True)

                            pred_class = np.argmax(prediction[0], axis=-1)
                            gt_class = gt_data[:, :, 0].astype(np.int32)

                            # Labels above 10 are raw class codes; map the i-th
                            # smallest code to class index i.
                            if np.max(gt_class) > 10:
                                gt_mapped = np.zeros_like(gt_class)
                                class_values = sorted(st.session_state.segmentation.class_definitions.values())
                                for class_pos, val in enumerate(class_values):
                                    gt_mapped[gt_class == val] = class_pos
                                gt_class = gt_mapped

                            accuracy = np.mean(pred_class == gt_class) * 100

                            st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
                            st.markdown(f"<p class='metric-value'>{accuracy:.2f}%</p>", unsafe_allow_html=True)
                            st.markdown("<p class='metric-label'>Pixel Accuracy</p>", unsafe_allow_html=True)
                            st.markdown("</div>", unsafe_allow_html=True)

                            st.download_button(
                                label="Download Result",
                                data=result_buf,
                                file_name=f"sample_result_{selected_sample}.png",
                                mime="image/png",
                                key="download_sample_result_with_gt"
                            )
                        else:
                            result_img = visualize_prediction(prediction, np.expand_dims(sar_normalized, axis=-1))
                            st.markdown("<h4 style='color: #a78bfa;'>Segmentation Results</h4>", unsafe_allow_html=True)
                            st.image(result_img, use_container_width=True)

                            st.download_button(
                                label="Download Result",
                                data=result_img,
                                file_name=f"sample_result_{selected_sample}.png",
                                mime="image/png",
                                key="download_sample_result"
                            )
                    except Exception as e:
                        st.error(f"Error processing sample image: {str(e)}")
        else:
            st.info("No sample images found. Please add some images to the 'samples/SAR' directory.")


# Closing tag for the card wrapper opened before the tabs.
st.markdown("</div>", unsafe_allow_html=True)
|
|
| elif st.session_state.app_mode == "SAR to Optical Translation": |
| |
| st.markdown(""" |
| <div style="text-align: center; margin-bottom: 2rem; position: relative; z-index: 100;"> |
| <h1 style="color: #a78bfa; font-size: 3rem; font-weight: bold; text-shadow: 0 0 10px rgba(167, 139, 250, 0.5);"> |
| SAR to Optical Translation |
| </h1> |
| <p style="color: #bfdbfe; font-size: 1.2rem;"> |
| Convert Synthetic Aperture Radar images to optical-like imagery using deep learning |
| </p> |
| </div> |
| """, unsafe_allow_html=True) |
| |
| |
| st.markdown("<div class='card'>", unsafe_allow_html=True) |
| |
| |
| models_loaded = 'unet_model' in st.session_state |
| |
| if not models_loaded: |
| st.warning("Please load the models from the sidebar first.") |
| else: |
| st.success("Models loaded successfully! You can now process SAR images.") |
| |
| |
| |
| tab1, tab2, tab3 = st.tabs(["Process Single Image", "Batch Processing", "Sample Images"]) |
|
|
| |
| with tab1: |
| st.markdown("<h3 style='color: #a78bfa;'>Upload SAR Image</h3>", unsafe_allow_html=True) |
| |
| |
| |
| col1, col2 = st.columns(2) |
| |
| with col1: |
| st.markdown("<div class='upload-box'>", unsafe_allow_html=True) |
| uploaded_file = st.file_uploader( |
| "Upload a SAR image (.tif or common image formats)", |
| type=["tif", "tiff", "png", "jpg", "jpeg"], |
| key="sar_optical_uploader" |
| ) |
| st.markdown("</div>", unsafe_allow_html=True) |
| |
| |
| with col2: |
| st.markdown("<div class='upload-box'>", unsafe_allow_html=True) |
| gt_file = st.file_uploader( |
| "Upload ground truth optical image (optional)", |
| type=["tif", "tiff", "png", "jpg", "jpeg"], |
| key="optical_gt_uploader" |
| ) |
| st.markdown("</div>", unsafe_allow_html=True) |
| |
| if uploaded_file is not None: |
| |
| if st.button("Generate Optical-like Image", key="generate_optical_btn"): |
| with st.spinner("Processing image..."): |
| try: |
| |
| sar_batch, sar_image = load_sar_image(uploaded_file) |
| |
| if sar_batch is not None: |
| |
| seg_mask, colorized = process_image( |
| sar_batch, |
| st.session_state['unet_model'], |
| st.session_state.get('generator_model') |
| ) |
| |
| |
| sar_rgb, colored_pred, overlay, colorized_img = visualize_results( |
| sar_image, seg_mask, colorized |
| ) |
| |
| |
| st.header("Results") |
| |
| |
| if gt_file is not None: |
| try: |
| |
| with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file: |
| tmp_file.write(gt_file.getbuffer()) |
| tmp_file_path = tmp_file.name |
|
|
| try: |
| |
| with rasterio.open(tmp_file_path) as src: |
| gt_image = src.read() |
| |
| st.info(f"Ground truth shape: {gt_image.shape}, dtype: {gt_image.dtype}, min: {np.min(gt_image)}, max: {np.max(gt_image)}") |
| |
| if gt_image.shape[0] == 3: |
| gt_image = np.transpose(gt_image, (1, 2, 0)) |
| else: |
| gt_image = src.read(1) |
| |
| if np.all(gt_image == 0) or np.all(gt_image == 1): |
| st.warning("Ground truth image appears to be blank (all zeros or ones)") |
| |
| |
| gt_image = np.expand_dims(gt_image, axis=-1) |
| gt_image = np.repeat(gt_image, 3, axis=-1) |
| except Exception as rasterio_error: |
| st.warning(f"Rasterio failed: {str(rasterio_error)}. Trying PIL...") |
| try: |
| |
| gt_image = np.array(Image.open(tmp_file_path).convert('RGB')) |
| |
| st.info(f"Ground truth shape (PIL): {gt_image.shape}, dtype: {gt_image.dtype}, min: {np.min(gt_image)}, max: {np.max(gt_image)}") |
| |
| |
| if np.all(gt_image > 250): |
| st.warning("Ground truth image appears to be all white") |
| except Exception as pil_error: |
| st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}") |
| raise |
|
|
| |
| os.unlink(tmp_file_path) |
|
|
| |
| if gt_image.shape[:2] != (256, 256): |
| gt_image = cv2.resize(gt_image, (256, 256)) |
|
|
| |
| if gt_image.dtype != np.uint8: |
| if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255: |
| gt_image = gt_image.astype(np.uint8) |
| elif np.max(gt_image) <= 1.0: |
| gt_image = (gt_image * 255).astype(np.uint8) |
| else: |
| |
| gt_min, gt_max = np.min(gt_image), np.max(gt_image) |
| gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8) |
|
|
| |
| |
| fig, axes = plt.subplots(1, 4, figsize=(16, 4)) |
| |
| |
| axes[0].imshow(sar_rgb, cmap='gray') |
| axes[0].set_title('Original SAR', color='white') |
| axes[0].axis('off') |
| |
| |
| axes[1].imshow(gt_image) |
| axes[1].set_title('Ground Truth', color='white') |
| axes[1].axis('off') |
| |
| |
| axes[2].imshow(colored_pred) |
| axes[2].set_title('Segmentation', color='white') |
| axes[2].axis('off') |
| |
| |
| if colorized_img is not None: |
| |
| colorized_display = (colorized_img * 0.5) + 0.5 |
| axes[3].imshow(colorized_display) |
| else: |
| axes[3].imshow(overlay) |
| axes[3].set_title('Generated Image', color='white') |
| axes[3].axis('off') |
| |
| |
| fig.patch.set_facecolor('#0a0a1f') |
| for ax in axes: |
| ax.set_facecolor('#0a0a1f') |
| |
| plt.tight_layout() |
| |
| |
| st.pyplot(fig) |
| |
| |
| if colorized_img is not None: |
| |
| colorized_norm = (colorized_img * 0.5) + 0.5 |
| gt_norm = gt_image.astype(np.float32) / 255.0 |
| |
| |
| mse = np.mean((colorized_norm - gt_norm) ** 2) |
| psnr = 20 * np.log10(1.0 / np.sqrt(mse)) |
| |
| |
| from skimage.metrics import structural_similarity as ssim |
|
|
| try: |
| |
| min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1]) |
| win_size = min(7, min_dim - (min_dim % 2) + 1) |
| |
| ssim_value = ssim( |
| colorized_norm, |
| gt_norm, |
| win_size=win_size, |
| channel_axis=2, |
| data_range=1.0 |
| ) |
| except Exception as e: |
| st.warning(f"Could not calculate SSIM: {str(e)}") |
| ssim_value = 0.0 |
|
|
| |
| |
| col1, col2 = st.columns(2) |
| with col1: |
| st.markdown("<div class='metric-card'>", unsafe_allow_html=True) |
| st.markdown(f"<p class='metric-value'>{psnr:.2f}</p>", unsafe_allow_html=True) |
| st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True) |
| st.markdown("</div>", unsafe_allow_html=True) |
| |
| with col2: |
| st.markdown("<div class='metric-card'>", unsafe_allow_html=True) |
| st.markdown(f"<p class='metric-value'>{ssim_value:.4f}</p>", unsafe_allow_html=True) |
| st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True) |
| st.markdown("</div>", unsafe_allow_html=True) |
| except Exception as e: |
| st.error(f"Error processing ground truth: {str(e)}") |
| |
| col1, col2, col3 = st.columns(3) |
| |
| with col1: |
| st.subheader("Original SAR Image") |
| st.image(sar_rgb, use_container_width=True) |
| |
| with col2: |
| st.subheader("Predicted Segmentation") |
| st.image(colored_pred, use_container_width=True) |
| |
| with col3: |
| st.subheader("Colorized SAR") |
| st.image(overlay, use_container_width=True) |
| else: |
| |
| col1, col2, col3 = st.columns(3) |
| |
| with col1: |
| st.subheader("Original SAR Image") |
| st.image(sar_rgb, use_container_width=True) |
| |
| with col2: |
| st.subheader("Predicted Segmentation") |
| st.image(colored_pred, use_container_width=True) |
| |
| with col3: |
| st.subheader("Colorized SAR") |
| st.image(overlay, use_container_width=True) |
| |
| |
| if colorized_img is not None: |
| st.header("Translated Optical Image") |
| |
| colorized_display = (colorized_img * 0.5) + 0.5 |
| |
| |
| fig, ax = plt.subplots(figsize=(6, 6)) |
| ax.imshow(colorized_display) |
| ax.axis('off') |
| |
| |
| st.pyplot(fig, use_container_width=False) |
| |
| |
| col1, col2 = st.columns(2) |
| |
| with col1: |
| |
| seg_buf = io.BytesIO() |
| plt.imsave(seg_buf, colored_pred, format='png') |
| seg_buf.seek(0) |
| |
| st.download_button( |
| label="Download Segmentation", |
| data=seg_buf, |
| file_name="segmentation.png", |
| mime="image/png", |
| key="download_seg" |
| ) |
| |
| with col2: |
| |
| gen_buf = io.BytesIO() |
| plt.imsave(gen_buf, colorized_display, format='png') |
| gen_buf.seek(0) |
| |
| st.download_button( |
| label="Download Optical-like Image", |
| data=gen_buf, |
| file_name="optical_like.png", |
| mime="image/png", |
| key="download_optical" |
| ) |
| except Exception as e: |
| st.error(f"Error processing image: {str(e)}") |
| |
| |
| with tab2: |
| st.markdown("<h3 style='color: #a78bfa;'>Batch Process SAR Images</h3>", unsafe_allow_html=True) |
| |
| st.markdown("<div class='upload-box'>", unsafe_allow_html=True) |
| |
| |
| use_gt = st.checkbox("Include ground truth data", value=False) |
| |
| col1, col2 = st.columns(2) |
| |
| with col1: |
| batch_files = st.file_uploader( |
| "Upload SAR images or a ZIP file containing images", |
| type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], |
| accept_multiple_files=True, |
| key="batch_sar_optical_uploader" |
| ) |
| |
| |
| batch_gt_files = None |
| if use_gt: |
| with col2: |
| batch_gt_files = st.file_uploader( |
| "Upload ground truth optical images or a ZIP file (must match SAR filenames)", |
| type=["tif", "tiff", "png", "jpg", "jpeg", "zip"], |
| accept_multiple_files=True, |
| key="batch_optical_gt_uploader" |
| ) |
| st.info("Ground truth filenames should match SAR image filenames") |
| |
| st.markdown("</div>", unsafe_allow_html=True) |
| |
| col1, col2 = st.columns([3, 1]) |
| |
| with col1: |
| max_images = st.slider("Maximum number of images to display", min_value=1, max_value=20, value=5) |
| |
| with col2: |
| st.markdown("<br>", unsafe_allow_html=True) |
| batch_process_btn = st.button("Process Images", key="batch_process_btn") |
| |
| if batch_process_btn and batch_files: |
| |
| if 'batch_results' not in st.session_state: |
| st.session_state.batch_results = [] |
| else: |
| st.session_state.batch_results = [] |
| |
| |
| with st.spinner("Processing images..."): |
| |
| with tempfile.TemporaryDirectory() as temp_dir: |
| |
| sar_image_files = [] |
| gt_image_files = {} |
| |
| |
| for uploaded_file in batch_files: |
| if uploaded_file.name.lower().endswith('.zip'): |
| |
| zip_path = os.path.join(temp_dir, uploaded_file.name) |
| with open(zip_path, 'wb') as f: |
| f.write(uploaded_file.getvalue()) |
| |
| with zipfile.ZipFile(zip_path, 'r') as zip_ref: |
| zip_ref.extractall(os.path.join(temp_dir, 'sar')) |
| |
| |
| for root, _, files in os.walk(os.path.join(temp_dir, 'sar')): |
| for file in files: |
| if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): |
| sar_image_files.append(os.path.join(root, file)) |
| else: |
| |
| file_path = os.path.join(temp_dir, 'sar', uploaded_file.name) |
| os.makedirs(os.path.dirname(file_path), exist_ok=True) |
| with open(file_path, 'wb') as f: |
| f.write(uploaded_file.getvalue()) |
| sar_image_files.append(file_path) |
| |
| |
| if use_gt and batch_gt_files: |
| for gt_file in batch_gt_files: |
| if gt_file.name.lower().endswith('.zip'): |
| |
| zip_path = os.path.join(temp_dir, gt_file.name) |
| with open(zip_path, 'wb') as f: |
| f.write(gt_file.getvalue()) |
| |
| with zipfile.ZipFile(zip_path, 'r') as zip_ref: |
| zip_ref.extractall(os.path.join(temp_dir, 'gt')) |
| |
| |
| for root, _, files in os.walk(os.path.join(temp_dir, 'gt')): |
| for file in files: |
| if file.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')): |
| |
| gt_path = os.path.join(root, file) |
| gt_image_files[os.path.basename(file)] = gt_path |
| else: |
| |
| file_path = os.path.join(temp_dir, 'gt', gt_file.name) |
| os.makedirs(os.path.dirname(file_path), exist_ok=True) |
| with open(file_path, 'wb') as f: |
| f.write(gt_file.getvalue()) |
| gt_image_files[os.path.basename(gt_file.name)] = file_path |
| |
| |
| if len(sar_image_files) > max_images: |
| st.info(f"Found {len(sar_image_files)} images. Randomly selecting {max_images} images to display.") |
| sar_image_files = random.sample(sar_image_files, max_images) |
| |
| |
| progress_bar = st.progress(0) |
| |
| |
| if use_gt and gt_image_files: |
| overall_psnr = [] |
| overall_ssim = [] |
| |
| for i, image_path in enumerate(sar_image_files): |
| try: |
| |
| progress_bar.progress((i + 1) / len(sar_image_files)) |
| |
| |
| with open(image_path, 'rb') as f: |
| file_bytes = f.read() |
| |
| sar_batch, sar_image = load_sar_image(io.BytesIO(file_bytes)) |
| |
| if sar_batch is not None: |
| |
| seg_mask, colorized = process_image( |
| sar_batch, |
| st.session_state['unet_model'], |
| st.session_state.get('generator_model') |
| ) |
| |
| |
| sar_rgb, colored_pred, overlay, colorized_img = visualize_results( |
| sar_image, seg_mask, colorized |
| ) |
| |
| |
| image_basename = os.path.basename(image_path) |
| has_gt = image_basename in gt_image_files |
| |
| if has_gt and use_gt: |
| |
| gt_path = gt_image_files[image_basename] |
| try: |
| |
| with rasterio.open(gt_path) as src: |
| gt_image = src.read() |
| |
| if gt_image.shape[0] == 3: |
| gt_image = np.transpose(gt_image, (1, 2, 0)) |
| else: |
| gt_image = src.read(1) |
| |
| |
| gt_image = np.expand_dims(gt_image, axis=-1) |
| gt_image = np.repeat(gt_image, 3, axis=-1) |
| except Exception as rasterio_error: |
| try: |
| |
| gt_image = np.array(Image.open(gt_path).convert('RGB')) |
| except Exception as pil_error: |
| st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}") |
| raise |
|
|
| |
| if gt_image.shape[:2] != (256, 256): |
| gt_image = cv2.resize(gt_image, (256, 256)) |
|
|
| |
| if gt_image.dtype != np.uint8: |
| if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255: |
| gt_image = gt_image.astype(np.uint8) |
| elif np.max(gt_image) <= 1.0: |
| gt_image = (gt_image * 255).astype(np.uint8) |
| else: |
| |
| gt_min, gt_max = np.min(gt_image), np.max(gt_image) |
| gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8) |
|
|
| |
| |
| fig, axes = plt.subplots(1, 4, figsize=(16, 4)) |
| |
| |
| axes[0].imshow(sar_rgb, cmap='gray') |
| axes[0].set_title('Original SAR', color='white') |
| axes[0].axis('off') |
| |
| |
| axes[1].imshow(gt_image) |
| axes[1].set_title('Ground Truth', color='white') |
| axes[1].axis('off') |
| |
| |
| axes[2].imshow(colored_pred) |
| axes[2].set_title('Segmentation', color='white') |
| axes[2].axis('off') |
| |
| |
| if colorized_img is not None: |
| |
| colorized_display = (colorized_img * 0.5) + 0.5 |
| axes[3].imshow(colorized_display) |
| else: |
| axes[3].imshow(overlay) |
| axes[3].set_title('Generated Image', color='white') |
| axes[3].axis('off') |
| |
| |
| fig.patch.set_facecolor('#0a0a1f') |
| for ax in axes: |
| ax.set_facecolor('#0a0a1f') |
| |
| plt.tight_layout() |
| |
| |
| result_buf = io.BytesIO() |
| plt.savefig(result_buf, format='png', facecolor='#0a0a1f', bbox_inches='tight') |
| result_buf.seek(0) |
| plt.close(fig) |
| |
| |
| metrics = {'psnr': 0.0, 'ssim': 0.0} |
| if colorized_img is not None: |
| try: |
| |
| colorized_norm = (colorized_img * 0.5) + 0.5 |
| gt_norm = gt_image.astype(np.float32) / 255.0 |
| |
| |
| mse = np.mean((colorized_norm - gt_norm) ** 2) |
| if mse > 0: |
| psnr = 20 * np.log10(1.0 / np.sqrt(mse)) |
| metrics['psnr'] = psnr |
| overall_psnr.append(psnr) |
| |
| |
| from skimage.metrics import structural_similarity as ssim |
| |
| |
| min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1]) |
| win_size = min(7, min_dim - (min_dim % 2) + 1) |
| |
| ssim_value = ssim( |
| colorized_norm, |
| gt_norm, |
| win_size=win_size, |
| channel_axis=2, |
| data_range=1.0 |
| ) |
| metrics['ssim'] = ssim_value |
| overall_ssim.append(ssim_value) |
| except Exception as e: |
| st.warning(f"Could not calculate metrics for {os.path.basename(image_path)}: {str(e)}") |
|
|
| |
| |
| gen_buf = io.BytesIO() |
| if colorized_img is not None: |
| plt.imsave(gen_buf, colorized_display, format='png') |
| else: |
| plt.imsave(gen_buf, overlay, format='png') |
| gen_buf.seek(0) |
| |
| |
| st.session_state.batch_results.append({ |
| 'filename': os.path.basename(image_path), |
| 'result': result_buf, |
| 'generated': gen_buf, |
| 'metrics': metrics |
| }) |
| else: |
| |
| fig, axes = plt.subplots(1, 3, figsize=(12, 4)) |
| |
| |
| axes[0].imshow(sar_rgb, cmap='gray') |
| axes[0].set_title('Original SAR', color='white') |
| axes[0].axis('off') |
| |
| |
| axes[1].imshow(colored_pred) |
| axes[1].set_title('Segmentation', color='white') |
| axes[1].axis('off') |
| |
| |
| if colorized_img is not None: |
| |
| colorized_display = (colorized_img * 0.5) + 0.5 |
| axes[2].imshow(colorized_display) |
| else: |
| axes[2].imshow(overlay) |
| axes[2].set_title('Generated Image', color='white') |
| axes[2].axis('off') |
| |
| |
| fig.patch.set_facecolor('#0a0a1f') |
| for ax in axes: |
| ax.set_facecolor('#0a0a1f') |
| |
| plt.tight_layout() |
| |
| |
| result_buf = io.BytesIO() |
| plt.savefig(result_buf, format='png', facecolor='#0a0a1f', bbox_inches='tight') |
| result_buf.seek(0) |
| plt.close(fig) |
| |
| |
| gen_buf = io.BytesIO() |
| if colorized_img is not None: |
| plt.imsave(gen_buf, colorized_display, format='png') |
| else: |
| plt.imsave(gen_buf, overlay, format='png') |
| gen_buf.seek(0) |
| |
| |
| st.session_state.batch_results.append({ |
| 'filename': os.path.basename(image_path), |
| 'result': result_buf, |
| 'generated': gen_buf |
| }) |
| except Exception as e: |
| st.error(f"Error processing {os.path.basename(image_path)}: {str(e)}") |
| |
| |
| progress_bar.empty() |
| |
| |
| if st.session_state.batch_results: |
| st.markdown("<h4 style='color: #a78bfa;'>Translation Results</h4>", unsafe_allow_html=True) |
| |
| |
| if use_gt and 'overall_psnr' in locals() and overall_psnr: |
| avg_psnr = np.mean(overall_psnr) |
| avg_ssim = np.mean(overall_ssim) |
| |
| col1, col2 = st.columns(2) |
| with col1: |
| st.markdown("<div class='metric-card'>", unsafe_allow_html=True) |
| st.markdown(f"<p class='metric-value'>{avg_psnr:.2f}</p>", unsafe_allow_html=True) |
| st.markdown("<p class='metric-label'>Average PSNR (dB)</p>", unsafe_allow_html=True) |
| st.markdown("</div>", unsafe_allow_html=True) |
| |
| with col2: |
| st.markdown("<div class='metric-card'>", unsafe_allow_html=True) |
| st.markdown(f"<p class='metric-value'>{avg_ssim:.4f}</p>", unsafe_allow_html=True) |
| st.markdown("<p class='metric-label'>Average SSIM</p>", unsafe_allow_html=True) |
| st.markdown("</div>", unsafe_allow_html=True) |
| |
| |
| zip_buffer = io.BytesIO() |
| with zipfile.ZipFile(zip_buffer, 'w') as zip_file: |
| for i, result in enumerate(st.session_state.batch_results): |
| |
| zip_file.writestr(f"result_{i+1}_{result['filename']}.png", result['result'].getvalue()) |
| |
| zip_file.writestr(f"generated_{i+1}_{result['filename']}.png", result['generated'].getvalue()) |
| |
| |
| st.download_button( |
| label="Download All Results", |
| data=zip_buffer.getvalue(), |
| file_name="translation_results.zip", |
| mime="application/zip", |
| key="download_all_translation_results" |
| ) |
| |
| |
| for i, result in enumerate(st.session_state.batch_results): |
| st.markdown(f"<h5 style='color: #bfdbfe;'>Image: {result['filename']}</h5>", unsafe_allow_html=True) |
| |
| |
| if 'metrics' in result: |
| col1, col2 = st.columns(2) |
| with col1: |
| st.markdown("<div class='metric-card'>", unsafe_allow_html=True) |
| if 'psnr' in result['metrics']: |
| st.markdown(f"<p class='metric-value'>{result['metrics']['psnr']:.2f}</p>", unsafe_allow_html=True) |
| else: |
| st.markdown("<p class='metric-value'>N/A</p>", unsafe_allow_html=True) |
| st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True) |
| st.markdown("</div>", unsafe_allow_html=True) |
| |
| with col2: |
| st.markdown("<div class='metric-card'>", unsafe_allow_html=True) |
| if 'ssim' in result['metrics']: |
| st.markdown(f"<p class='metric-value'>{result['metrics']['ssim']:.4f}</p>", unsafe_allow_html=True) |
| else: |
| st.markdown("<p class='metric-value'>N/A</p>", unsafe_allow_html=True) |
| st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True) |
| st.markdown("</div>", unsafe_allow_html=True) |
|
|
| |
| st.image(result['result'], use_container_width=True) |
| |
| |
| col1, col2 = st.columns(2) |
| with col1: |
| st.download_button( |
| label="Download Visualization", |
| data=result['result'].getvalue(), |
| file_name=f"result_{result['filename']}.png", |
| mime="image/png", |
| key=f"download_viz_{i}" |
| ) |
| |
| with col2: |
| st.download_button( |
| label="Download Generated Image", |
| data=result['generated'].getvalue(), |
| file_name=f"generated_{result['filename']}.png", |
| mime="image/png", |
| key=f"download_gen_{i}" |
| ) |
| |
| st.markdown("<hr style='border-color: rgba(147, 51, 234, 0.3);'>", unsafe_allow_html=True) |
| else: |
| st.warning("No images were successfully processed.") |
| elif batch_process_btn: |
| st.warning("Please upload at least one image file or ZIP archive.") |
| |
| with tab3: |
| st.markdown("<h3 style='color: #a78bfa;'>Sample Images</h3>", unsafe_allow_html=True) |
| st.markdown("<div class='card'>", unsafe_allow_html=True) |
| |
| |
| import os |
| sample_dir = "samples/SAR" |
| if os.path.exists(sample_dir): |
| sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg'))] |
| else: |
| os.makedirs(sample_dir, exist_ok=True) |
| os.makedirs("samples/OPTICAL", exist_ok=True) |
| os.makedirs("samples/LABELS", exist_ok=True) |
| sample_files = [] |
| |
| if sample_files and 'unet_model' in st.session_state: |
| |
| selected_sample = st.selectbox( |
| "Select a sample image", |
| sample_files, |
| key="optical_sample_selector" |
| ) |
| |
| |
| col1, col2 = st.columns(2) |
| |
| with col1: |
| st.subheader("SAR Image") |
| sar_path = os.path.join("samples/SAR", selected_sample) |
| display_image(sar_path) |
|
|
|
|
| with col2: |
| st.subheader("Optical Image (Ground Truth)") |
| |
| opt_path = os.path.join("samples/OPTICAL", selected_sample) |
| if os.path.exists(opt_path): |
| display_image(opt_path) |
| else: |
| st.info("No matching optical image found") |
| |
| |
| if st.button("Generate Optical-like Image", key="process_optical_sample_btn"): |
| with st.spinner("Processing sample image..."): |
| try: |
| |
| with open(sar_path, 'rb') as f: |
| file_bytes = f.read() |
| |
| sar_batch, sar_image = load_sar_image(io.BytesIO(file_bytes)) |
| |
| if sar_batch is not None: |
| |
| seg_mask, colorized = process_image( |
| sar_batch, |
| st.session_state['unet_model'], |
| st.session_state.get('generator_model') |
| ) |
| |
| |
| sar_rgb, colored_pred, overlay, colorized_img = visualize_results( |
| sar_image, seg_mask, colorized |
| ) |
| |
| |
| has_gt = os.path.exists(opt_path) |
| |
| if has_gt: |
| |
| try: |
| |
| with rasterio.open(opt_path) as src: |
| gt_image = src.read() |
| |
| if gt_image.shape[0] == 3: |
| gt_image = np.transpose(gt_image, (1, 2, 0)) |
| else: |
| gt_image = src.read(1) |
| |
| |
| |
| gt_image = np.expand_dims(gt_image, axis=-1) |
| gt_image = np.repeat(gt_image, 3, axis=-1) |
| except Exception as rasterio_error: |
| try: |
| |
| gt_image = np.array(Image.open(opt_path).convert('RGB')) |
| except Exception as pil_error: |
| st.error(f"Both rasterio and PIL failed to load the ground truth: {str(pil_error)}") |
| raise |
|
|
| |
| if gt_image.shape[:2] != (256, 256): |
| gt_image = cv2.resize(gt_image, (256, 256)) |
|
|
| |
| if gt_image.dtype != np.uint8: |
| if np.max(gt_image) > 1.0 and np.max(gt_image) <= 255: |
| gt_image = gt_image.astype(np.uint8) |
| elif np.max(gt_image) <= 1.0: |
| gt_image = (gt_image * 255).astype(np.uint8) |
| else: |
| |
| gt_min, gt_max = np.min(gt_image), np.max(gt_image) |
| gt_image = ((gt_image - gt_min) / (gt_max - gt_min) * 255).astype(np.uint8) |
| |
| |
| fig, axes = plt.subplots(1, 4, figsize=(16, 4)) |
| |
| |
| axes[0].imshow(sar_rgb, cmap='gray') |
| axes[0].set_title('Original SAR', color='white') |
| axes[0].axis('off') |
| |
| |
| axes[1].imshow(gt_image) |
| axes[1].set_title('Ground Truth', color='white') |
| axes[1].axis('off') |
| |
| |
| axes[2].imshow(colored_pred) |
| axes[2].set_title('Segmentation', color='white') |
| axes[2].axis('off') |
| |
| |
| if colorized_img is not None: |
| |
| colorized_display = (colorized_img * 0.5) + 0.5 |
| axes[3].imshow(colorized_display) |
| else: |
| axes[3].imshow(overlay) |
| axes[3].set_title('Generated Image', color='white') |
| axes[3].axis('off') |
| |
| |
| fig.patch.set_facecolor('#0a0a1f') |
| for ax in axes: |
| ax.set_facecolor('#0a0a1f') |
| |
| plt.tight_layout() |
| |
| |
| st.pyplot(fig) |
| |
| |
| if colorized_img is not None: |
| |
| colorized_norm = (colorized_img * 0.5) + 0.5 |
| gt_norm = gt_image.astype(np.float32) / 255.0 |
| |
| |
| mse = np.mean((colorized_norm - gt_norm) ** 2) |
| psnr = 20 * np.log10(1.0 / np.sqrt(mse)) |
| |
| |
| from skimage.metrics import structural_similarity as ssim |
|
|
| try: |
| |
| min_dim = min(colorized_norm.shape[0], colorized_norm.shape[1]) |
| win_size = min(7, min_dim - (min_dim % 2) + 1) |
| |
| ssim_value = ssim( |
| colorized_norm, |
| gt_norm, |
| win_size=win_size, |
| channel_axis=2, |
| data_range=1.0 |
| ) |
| except Exception as e: |
| st.warning(f"Could not calculate SSIM: {str(e)}") |
| ssim_value = 0.0 |
| |
| |
| col1, col2 = st.columns(2) |
| with col1: |
| st.markdown("<div class='metric-card'>", unsafe_allow_html=True) |
| st.markdown(f"<p class='metric-value'>{psnr:.2f}</p>", unsafe_allow_html=True) |
| st.markdown("<p class='metric-label'>PSNR (dB)</p>", unsafe_allow_html=True) |
| st.markdown("</div>", unsafe_allow_html=True) |
| |
| with col2: |
| st.markdown("<div class='metric-card'>", unsafe_allow_html=True) |
| st.markdown(f"<p class='metric-value'>{ssim_value:.4f}</p>", unsafe_allow_html=True) |
| st.markdown("<p class='metric-label'>SSIM</p>", unsafe_allow_html=True) |
| st.markdown("</div>", unsafe_allow_html=True) |
| else: |
| |
| col1, col2, col3 = st.columns(3) |
| |
| with col1: |
| st.subheader("Original SAR Image") |
| st.image(sar_rgb, use_container_width=True) |
| |
| with col2: |
| st.subheader("Predicted Segmentation") |
| st.image(colored_pred, use_container_width=True) |
| |
| with col3: |
| st.subheader("Colorized SAR") |
| if colorized_img is not None: |
| |
| colorized_display = (colorized_img * 0.5) + 0.5 |
| st.image(colorized_display, use_container_width=True) |
| else: |
| st.image(overlay, use_container_width=True) |
| |
| |
| col1, col2 = st.columns(2) |
| |
| with col1: |
| |
| seg_buf = io.BytesIO() |
| plt.imsave(seg_buf, colored_pred, format='png') |
| seg_buf.seek(0) |
| |
| st.download_button( |
| label="Download Segmentation", |
| data=seg_buf, |
| file_name=f"sample_segmentation_{selected_sample}.png", |
| mime="image/png", |
| key="download_sample_seg" |
| ) |
| |
| with col2: |
| |
| gen_buf = io.BytesIO() |
| if colorized_img is not None: |
| plt.imsave(gen_buf, (colorized_img * 0.5) + 0.5, format='png') |
| else: |
| plt.imsave(gen_buf, overlay, format='png') |
| gen_buf.seek(0) |
| |
| st.download_button( |
| label="Download Optical-like Image", |
| data=gen_buf, |
| file_name=f"sample_optical_{selected_sample}.png", |
| mime="image/png", |
| key="download_sample_optical" |
| ) |
| except Exception as e: |
| st.error(f"Error processing sample image: {str(e)}") |
| elif not sample_files: |
| st.info("No sample images found. Please add some images to the 'samples/SAR' directory.") |
| else: |
| st.warning("Please load the models from the sidebar first.") |
| |
|
|
|
|
| |
| st.markdown("</div>", unsafe_allow_html=True) |
|
|
| |
| st.markdown(""" |
| <div style="text-align: center; margin-top: 2rem; padding: 1rem; background-color: rgba(0, 0, 0, 0.3); border-radius: 0.5rem;"> |
| <p style="color: #bfdbfe; font-size: 0.9rem;"> |
| SAR IMAGE PROCESSING | VARUN & MOKSHYAGNA |
| </p> |
| </div> |
| """, unsafe_allow_html=True) |
| |
|
|
def create_stars_html(n_stars=100, rng=None):
    """Build the HTML for the twinkling-star background overlay.

    Each star is a ``<div class="star">`` whose size, position, animation
    duration and resting opacity are drawn at random. The ``.stars`` /
    ``.star`` classes and the ``twinkle`` keyframes that style this markup
    are defined in the CSS built by ``setup_page_style``.

    Args:
        n_stars: Number of star elements to emit. The default of 100 is the
            value that was previously hard-coded. Zero or negative values
            produce an empty container.
        rng: Optional object exposing ``uniform(a, b)``, e.g. a seeded
            ``random.Random``, for reproducible output. When omitted, the
            module-level ``random`` functions are used, as before.

    Returns:
        An HTML string: a ``<div class="stars">`` wrapper containing the
        generated star ``<div>`` elements.
    """
    if rng is None:
        rng = random

    parts = ['<div class="stars">']
    for _ in range(n_stars):
        # Values are drawn in the same order as the original implementation
        # (size, top, left, duration, opacity).
        size = rng.uniform(1, 3)          # width/height in px
        top = rng.uniform(0, 100)         # vertical offset, % of the container
        left = rng.uniform(0, 100)        # horizontal offset, % of the container
        duration = rng.uniform(3, 8)      # feeds the CSS --duration variable (s)
        opacity = rng.uniform(0.2, 0.8)   # feeds the CSS --opacity variable
        parts.append(
            f'<div class="star" style="'
            f'width: {size}px; height: {size}px; '
            f'top: {top}%; left: {left}%; '
            f'--duration: {duration}s; --opacity: {opacity};'
            f'"></div>'
        )
    parts.append("</div>")
    # join avoids the quadratic cost of repeated string concatenation.
    return "\n".join(parts)
|
|
| |
| |
| def setup_page_style(): |
| """Set up the page style with CSS based on selected theme""" |
| |
| |
| common_css = """ |
| /* Create twinkling stars effect */ |
| @keyframes twinkle { |
| 0%, 100% { opacity: 0.2; } |
| 50% { opacity: 1; } |
| } |
| |
| .stars { |
| position: fixed; |
| top: 0; |
| left: 0; |
| width: 100%; |
| height: 100%; |
| pointer-events: none; |
| z-index: -1; |
| } |
| |
| .star { |
| position: absolute; |
| background-color: white; |
| border-radius: 50%; |
| animation: twinkle var(--duration) infinite; |
| opacity: var(--opacity); |
| } |
| |
| /* Tab styling */ |
| .stTabs [data-baseweb="tab-list"] { |
| gap: 24px !important; |
| border-radius: 0.5rem; |
| padding: 0.8rem; |
| margin-bottom: 3rem !important; |
| display: flex; |
| justify-content: center !important; |
| width: 100%; |
| } |
| |
| .stTabs [data-baseweb="tab"] { |
| height: 5rem !important; |
| white-space: pre-wrap; |
| border-radius: 0.5rem; |
| font-weight: 600 !important; |
| font-size: 1.6rem !important; |
| padding: 0 25px !important; |
| display: flex; |
| align-items: center; |
| justify-content: center; |
| min-width: 200px !important; |
| } |
| |
| /* Add more space between tab panels */ |
| .stTabs [data-baseweb="tab-panel"] { |
| padding-top: 3rem !important; |
| padding-bottom: 3rem !important; |
| } |
| |
| /* Button styling */ |
| .stButton>button { |
| border: none; |
| border-radius: 0.5rem; |
| padding: 0.8rem 1.5rem !important; |
| font-weight: 500; |
| font-size: 1.2rem !important; |
| margin-top: 1.5rem !important; |
| margin-bottom: 1.5rem !important; |
| } |
| |
| /* Spacing */ |
| .element-container { |
| margin-bottom: 2.5rem !important; |
| } |
| |
| h3 { |
| margin-top: 3rem !important; |
| margin-bottom: 2rem !important; |
| font-size: 1.8rem !important; |
| } |
| |
| h4 { |
| margin-top: 2.5rem !important; |
| margin-bottom: 1.5rem !important; |
| font-size: 1.5rem !important; |
| } |
| |
| h5 { |
| margin-top: 2rem !important; |
| margin-bottom: 1.5rem !important; |
| font-size: 1.3rem !important; |
| } |
| |
| img { |
| margin-top: 1.5rem !important; |
| margin-bottom: 2.5rem !important; |
| } |
| |
| .stProgress > div { |
| margin-top: 2rem !important; |
| margin-bottom: 2rem !important; |
| } |
| |
| .stSlider { |
| padding-top: 1.5rem !important; |
| padding-bottom: 2.5rem !important; |
| } |
| |
| .row-widget { |
| margin-top: 1.5rem !important; |
| margin-bottom: 2.5rem !important; |
| } |
| """ |
| |
| |
| dark_css = """ |
| .stApp { |
| background-color: #0a0a1f; |
| color: white; |
| } |
| |
| .main { |
| background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80"); |
| background-size: cover; |
| background-position: center; |
| background-repeat: no-repeat; |
| background-attachment: fixed; |
| position: relative; |
| } |
| |
| .main::before { |
| content: ""; |
| position: absolute; |
| top: 0; |
| left: 0; |
| width: 100%; |
| height: 100%; |
| background-color: rgba(10, 10, 31, 0.7); |
| backdrop-filter: blur(5px); |
| z-index: -1; |
| } |
| |
| /* Title styling */ |
| h1.title { |
| background: linear-gradient(to right, #a78bfa, #ec4899, #3b82f6); |
| -webkit-background-clip: text; |
| -webkit-text-fill-color: transparent; |
| background-clip: text; |
| color: transparent; |
| font-size: 3rem !important; |
| font-weight: bold !important; |
| text-align: center !important; |
| margin-bottom: 0.5rem !important; |
| display: block !important; |
| position: relative !important; |
| z-index: 10 !important; |
| } |
| |
| p.subtitle { |
| color: #bfdbfe !important; |
| font-size: 1.2rem !important; |
| text-align: center !important; |
| margin-bottom: 2rem !important; |
| position: relative !important; |
| z-index: 10 !important; |
| } |
| |
| /* Tab styling */ |
| .stTabs [data-baseweb="tab-list"] { |
| background-color: rgba(0, 0, 0, 0.3); |
| } |
| |
| .stTabs [data-baseweb="tab"] { |
| background-color: transparent; |
| color: white; |
| } |
| |
| .stTabs [aria-selected="true"] { |
| background-color: rgba(147, 51, 234, 0.5) !important; |
| transform: scale(1.05); |
| transition: all 0.2s ease; |
| } |
| |
| /* Card and box styling */ |
| .upload-box { |
| border: 2px dashed rgba(147, 51, 234, 0.5); |
| border-radius: 1rem; |
| padding: 4rem !important; |
| text-align: center; |
| margin-bottom: 3rem !important; |
| } |
| |
| .card { |
| background-color: rgba(0, 0, 0, 0.3); |
| border: 1px solid rgba(147, 51, 234, 0.3); |
| border-radius: 1rem; |
| padding: 2.5rem !important; |
| backdrop-filter: blur(10px); |
| margin-bottom: 3rem !important; |
| } |
| |
| /* Button styling */ |
| .stButton>button { |
| background: linear-gradient(to right, #7c3aed, #2563eb); |
| color: white; |
| } |
| |
| .stButton>button:hover { |
| background: linear-gradient(to right, #6d28d9, #1d4ed8); |
| } |
| |
| .download-btn { |
| background-color: #2563eb !important; |
| } |
| |
| .stSlider>div>div>div { |
| background-color: #7c3aed; |
| } |
| |
| /* Metrics styling */ |
| .plot-container { |
| background-color: rgba(0, 0, 0, 0.3); |
| border-radius: 1rem; |
| padding: 2rem !important; |
| margin-bottom: 3rem !important; |
| } |
| |
| .metric-card { |
| background-color: rgba(0, 0, 0, 0.3); |
| border: 1px solid rgba(147, 51, 234, 0.3); |
| border-radius: 0.5rem; |
| padding: 1.5rem !important; |
| text-align: center; |
| margin-bottom: 2rem !important; |
| } |
| |
| .metric-value { |
| font-size: 2rem !important; |
| font-weight: bold; |
| color: #a78bfa; |
| } |
| |
| .metric-label { |
| font-size: 1.1rem !important; |
| color: #bfdbfe; |
| } |
| |
| /* Form elements */ |
| .stFileUploader > div { |
| background-color: rgba(0, 0, 0, 0.3) !important; |
| border: 1px dashed rgba(147, 51, 234, 0.5) !important; |
| padding: 2rem !important; |
| margin-bottom: 2rem !important; |
| } |
| |
| .stSelectbox > div > div { |
| background-color: rgba(0, 0, 0, 0.3) !important; |
| border: 1px solid rgba(147, 51, 234, 0.3) !important; |
| } |
| """ |
| |
| |
| |
| |
| light_css = """ |
| /* Keep the same dark background */ |
| .stApp { |
| background-color: #0a0a1f; |
| } |
| |
| .main { |
| background-image: url("https://images.unsplash.com/photo-1451187580459-43490279c0fa?ixlib=rb-1.2.1&auto=format&fit=crop&w=1352&q=80"); |
| background-size: cover; |
| background-position: center; |
| background-repeat: no-repeat; |
| background-attachment: fixed; |
| position: relative; |
| } |
| |
| .main::before { |
| content: ""; |
| position: absolute; |
| top: 0; |
| left: 0; |
| width: 100%; |
| height: 100%; |
| background-color: rgba(10, 10, 31, 0.7); |
| backdrop-filter: blur(5px); |
| z-index: -1; |
| } |
| |
| /* Make all text white/light */ |
| p, span, label, div, h1, h2, h3, h4, h5, h6, li { |
| color: white !important; |
| } |
| |
| /* Title styling - brighter gradient for better visibility */ |
| h1.title { |
| background: linear-gradient(to right, #d8b4fe, #f9a8d4, #93c5fd); |
| -webkit-background-clip: text; |
| -webkit-text-fill-color: transparent; |
| background-clip: text; |
| color: transparent; |
| font-size: 3rem !important; |
| font-weight: bold !important; |
| text-align: center !important; |
| margin-bottom: 0.5rem !important; |
| display: block !important; |
| position: relative !important; |
| z-index: 10 !important; |
| } |
| |
| p.subtitle { |
| color: #e0e7ff !important; /* Lighter purple */ |
| font-size: 1.2rem !important; |
| text-align: center !important; |
| margin-bottom: 2rem !important; |
| position: relative !important; |
| z-index: 10 !important; |
| } |
| |
| /* Tab styling - brighter for better visibility */ |
| .stTabs [data-baseweb="tab-list"] { |
| background-color: rgba(0, 0, 0, 0.3); |
| } |
| |
| .stTabs [data-baseweb="tab"] { |
| background-color: transparent; |
| color: white !important; |
| } |
| |
| .stTabs [aria-selected="true"] { |
| background-color: rgba(167, 139, 250, 0.5) !important; /* Brighter purple */ |
| transform: scale(1.05); |
| transition: all 0.2s ease; |
| } |
| |
| /* Card and box styling - brighter borders */ |
| .upload-box { |
| border: 2px dashed rgba(167, 139, 250, 0.7); /* Brighter purple */ |
| border-radius: 1rem; |
| padding: 4rem !important; |
| text-align: center; |
| margin-bottom: 3rem !important; |
| } |
| |
| .card { |
| background-color: rgba(0, 0, 0, 0.3); |
| border: 1px solid rgba(167, 139, 250, 0.5); /* Brighter purple */ |
| border-radius: 1rem; |
| padding: 2.5rem !important; |
| backdrop-filter: blur(10px); |
| margin-bottom: 3rem !important; |
| } |
| |
| /* Button styling - brighter gradient */ |
| .stButton>button { |
| background: linear-gradient(to right, #a78bfa, #60a5fa); |
| color: white; |
| } |
| |
| .stButton>button:hover { |
| background: linear-gradient(to right, #8b5cf6, #3b82f6); |
| } |
| |
| .download-btn { |
| background-color: #60a5fa !important; |
| } |
| |
| .stSlider>div>div>div { |
| background-color: #a78bfa; |
| } |
| |
| /* Metrics styling - brighter accents */ |
| .plot-container { |
| background-color: rgba(0, 0, 0, 0.3); |
| border-radius: 1rem; |
| padding: 2rem !important; |
| margin-bottom: 3rem !important; |
| } |
| |
| .metric-card { |
| background-color: rgba(0, 0, 0, 0.3); |
| border: 1px solid rgba(167, 139, 250, 0.5); /* Brighter purple */ |
| border-radius: 0.5rem; |
| padding: 1.5rem !important; |
| text-align: center; |
| margin-bottom: 2rem !important; |
| } |
| |
| .metric-value { |
| font-size: 2rem !important; |
| font-weight: bold; |
| color: #d8b4fe; /* Brighter purple */ |
| } |
| |
| .metric-label { |
| font-size: 1.1rem !important; |
| color: #e0e7ff; /* Lighter purple */ |
| } |
| |
| /* Form elements - brighter borders */ |
| .stFileUploader > div { |
| background-color: rgba(0, 0, 0, 0.3) !important; |
| border: 1px dashed rgba(167, 139, 250, 0.7) !important; /* Brighter purple */ |
| padding: 2rem !important; |
| margin-bottom: 2rem !important; |
| } |
| |
| .stSelectbox > div > div { |
| background-color: rgba(0, 0, 0, 0.3) !important; |
| border: 1px solid rgba(167, 139, 250, 0.5) !important; /* Brighter purple */ |
| } |
| |
| /* Make sure all text inputs have white text */ |
| input, textarea { |
| color: white !important; |
| } |
| |
| /* Ensure sidebar text is white */ |
| .css-1d391kg, .css-1lcbmhc { |
| color: white !important; |
| } |
| |
| /* Make sure plot text is visible on dark background */ |
| .js-plotly-plot .plotly .main-svg text { |
| fill: white !important; |
| } |
| |
| /* Keep stars visible in light theme */ |
| .star { |
| background-color: white; |
| opacity: 0.8; |
| } |
| |
| /* Make sure all streamlit elements have white text */ |
| .stMarkdown, .stText, .stCode, .stTextInput, .stTextArea, .stSelectbox, .stMultiselect, |
| .stSlider, .stCheckbox, .stRadio, .stNumber, .stDate, .stTime, .stDateInput, .stTimeInput { |
| color: white !important; |
| } |
| |
| /* Ensure dropdown options are visible */ |
| .stSelectbox ul li { |
| color: black !important; |
| } |
| """ |
|
|
| |
| |
| |
| if st.session_state.theme == "dark": |
| st.markdown(f"<style>{common_css}{dark_css}</style>", unsafe_allow_html=True) |
| else: |
| st.markdown(f"<style>{common_css}{light_css}</style>", unsafe_allow_html=True) |
|
|
| |
|
|
if __name__ == "__main__":
    # Page styling first (presumably injects the theme CSS via st.markdown),
    # then the cached TensorFlow GPU setup defined near the top of the file.
    setup_page_style()
    setup_gpu()

    # Per-session defaults, seeded only when a key is missing from
    # st.session_state. Each default comes from a zero-argument factory, so the
    # SARSegmentation object is only constructed when "segmentation" is absent
    # and every session gets its own fresh list for "processed_images".
    _session_defaults = (
        ("app_mode", lambda: "SAR Colorization"),
        ("model_loaded", lambda: False),
        ("segmentation", lambda: SARSegmentation(img_rows=256, img_cols=256)),
        ("processed_images", list),
    )
    for _key, _make_default in _session_defaults:
        if _key not in st.session_state:
            st.session_state[_key] = _make_default()
|
|
|
|
|
|
|
|