|
|
|
|
|
import streamlit as st |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from src.models import perform_custom_segmentation |
|
|
from src.utils import resize_image, download_image |
|
|
import os |
|
|
import torch |
|
|
|
|
|
|
|
|
# Default target size for resized images. Presumably (width, height), mirroring
# the sidebar defaults below — TODO confirm. Nothing in the visible code reads
# this constant; the size actually used comes from the sidebar inputs.
TARGET_SIZE = (750, 750)
|
|
|
|
|
def get_parameters_from_sidebar() -> dict:
    """Collect the segmentation settings chosen in the Streamlit sidebar.

    Returns:
        dict: One slider value for each of 'train_epoch', 'mod_dim1',
        'mod_dim2', 'min_label_num' and 'max_label_num', plus 'target_size'
        as a (width, height) tuple built from the two number inputs.
    """
    sidebar = st.sidebar
    sidebar.header("Segmentation Parameters")

    # Positional slider arguments per parameter, in display order:
    # (minimum, maximum, initial value).
    slider_specs = {
        'train_epoch': (1, 200, 43),
        'mod_dim1': (1, 128, 67),
        'mod_dim2': (1, 128, 63),
        'min_label_num': (1, 20, 3),
        'max_label_num': (1, 200, 25),
    }

    params = {}
    for key, (low, high, initial) in slider_specs.items():
        # 'train_epoch' is shown as 'Train Epoch', and so on.
        caption = key.replace('_', ' ').title()
        params[key] = sidebar.slider(caption, low, high, initial)

    width = sidebar.number_input("Target Size Width", 100, 1200, 750)
    height = sidebar.number_input("Target Size Height", 100, 1200, 750)
    params['target_size'] = (width, height)

    return params
|
|
def display_segmentation_results() -> None:
    """Render the segmented image currently held in session state."""
    segmented = st.session_state.segmented_image
    st.image(
        segmented,
        caption='Updated Segmented Image',
        use_column_width=True,
    )
|
|
|
|
|
def randomize_colors() -> None:
    """Give every distinct color in the segmented image a new random color.

    Reads ``st.session_state.segmented_image`` and overwrites its pixels in
    place, merges the old-to-new color mapping into
    ``st.session_state.new_colors``, and increments
    ``st.session_state.image_update_trigger``. Returns without doing anything
    when no segmented image is stored yet.
    """
    image = st.session_state.segmented_image
    if image is None:
        return

    # Distinct colors plus, for every pixel, the index of its color.
    unique_labels, inverse = np.unique(
        image.reshape(-1, 3), axis=0, return_inverse=True
    )

    random_colors = {
        tuple(label): tuple(np.random.randint(0, 256, size=3))
        for label in unique_labels
    }

    # Recolor all pixels from their original color in one lookup. Replacing
    # colors one after another in place would re-map pixels whose new color
    # happens to equal an old color that is processed later, merging segments.
    palette = np.array(
        [random_colors[tuple(label)] for label in unique_labels],
        dtype=image.dtype,
    )
    # reshape(-1): the inverse index is not 1-D on every NumPy release.
    image[...] = palette[inverse.reshape(-1)].reshape(image.shape)

    st.session_state.new_colors.update(random_colors)
    st.session_state.image_update_trigger += 1
|
|
|
|
|
def handle_color_picking() -> None:
    """Show one color picker per distinct color and apply the chosen colors.

    For each distinct color in ``st.session_state.segmented_image`` a picker
    (keyed ``label_<i>``) is rendered, pre-filled with that color. The picked
    colors are stored in ``st.session_state.new_colors`` (old RGB tuple ->
    new RGB tuple) and written into the image in place, after which
    ``st.session_state.image_update_trigger`` is incremented. Returns without
    doing anything when no segmented image is stored yet.

    Only the colors currently present in the image are recolored, and all of
    them are replaced simultaneously based on the image as it was on entry.
    Mappings left in ``new_colors`` from earlier runs are kept but not
    replayed, so an old entry can no longer override a fresh pick.
    """
    image = st.session_state.segmented_image
    if image is None:
        return

    # Distinct colors plus, for every pixel, the index of its color.
    unique_labels, inverse = np.unique(
        image.reshape(-1, 3), axis=0, return_inverse=True
    )

    palette = []
    for i, label in enumerate(unique_labels):
        hex_label = f'#{label[0]:02x}{label[1]:02x}{label[2]:02x}'
        new_color = st.color_picker(
            f"Choose a new color for label {i}", value=hex_label, key=f"label_{i}"
        )
        # '#rrggbb' -> (r, g, b)
        hex_digits = new_color.lstrip('#')
        new_color_rgb = tuple(int(hex_digits[j:j + 2], 16) for j in (0, 2, 4))
        st.session_state.new_colors[tuple(label)] = new_color_rgb
        palette.append(new_color_rgb)

    # Single lookup instead of per-color in-place masks: a pixel is recolored
    # exactly once, from the color it had when this function was entered.
    lookup = np.array(palette, dtype=image.dtype)
    # reshape(-1): the inverse index is not 1-D on every NumPy release.
    image[...] = lookup[inverse.reshape(-1)].reshape(image.shape)

    st.session_state.image_update_trigger += 1
|
|
|
|
|
def calculate_and_display_label_percentages() -> None:
    """Show the share of pixels covered by each distinct color of the image.

    Reads ``st.session_state.segmented_image`` and writes one line per
    distinct color: a small swatch in that color, the label number and the
    percentage of pixels with that color (two decimals).

    Labels are numbered by position in the sorted list of distinct colors,
    the same numbering the color pickers use. Counting colors directly,
    rather than their grayscale intensity, keeps two different colors that
    share a gray value from being reported as a single label.
    """
    image = st.session_state.segmented_image
    unique_colors, counts = np.unique(
        image.reshape(-1, 3), axis=0, return_counts=True
    )
    total_pixels = counts.sum()

    st.write("Label Percentages:")
    for label, (color, count) in enumerate(zip(unique_colors, counts)):
        percentage = (count / total_pixels) * 100
        hex_color = f'#{color[0]:02x}{color[1]:02x}{color[2]:02x}'
        # hex_color is built from numeric channel values only, so embedding
        # it in raw HTML is safe.
        color_box = f'<div style="display: inline-block; width: 20px; height: 20px; background-color: {hex_color}; margin-right: 10px;"></div>'
        st.markdown(f'{color_box} Label {label}: {percentage:.2f}%', unsafe_allow_html=True)
|
|
|
|
|
def main() -> None:
    """Run the PetroSeg Streamlit page.

    Renders the title and usage hints, initialises the session-state entries
    used by the helpers ('segmented_image', 'new_colors',
    'image_update_trigger'), reads the sidebar parameters and, once an image
    is uploaded, shows it, resizes it and offers segmentation, random
    recoloring, per-color pickers, label percentages and a download of the
    segmented image.
    """
    # NOTE(review): the block nesting below was reconstructed from a listing
    # that had lost its indentation — confirm against the intended layout.
    st.title("PetroSeg")
    st.info("""
    - **Training Epochs**: Higher values will lead to fewer segments but may take more time.
    - **Image Size**: For better efficiency, upload small-sized images.
    - **Cache**: For best results, clear the cache between different image uploads. You can do this from the menu in the top-right corner.
    """)

    if torch.cuda.is_available():
        # Respect a device selection already made through the environment;
        # otherwise fall back to the first GPU. (No `args` object exists in
        # this module, so the GPU id cannot come from parsed arguments.)
        os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

    if 'segmented_image' not in st.session_state:
        st.session_state.segmented_image = None
    if 'new_colors' not in st.session_state:
        st.session_state.new_colors = {}
    if 'image_update_trigger' not in st.session_state:
        st.session_state.image_update_trigger = 0

    params = get_parameters_from_sidebar()

    uploaded_image = st.sidebar.file_uploader(
        "Upload an image", type=["jpg", "png", "jpeg", "bmp", "tiff", "webp"]
    )
    if not uploaded_image:
        return

    # getvalue() returns the complete upload regardless of the current read
    # position, so decoding keeps working when the script is re-run.
    file_bytes = np.frombuffer(uploaded_image.getvalue(), dtype=np.uint8)
    # An empty buffer makes cv2.imdecode raise instead of returning None.
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None

    if image is None:
        st.error("Error loading image. Please check the file and try again.")
        return

    # OpenCV decodes to BGR; convert to RGB before display and processing.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption='Original Image', use_column_width=True)

    target_size = params['target_size']
    image_resized = resize_image(image_rgb, target_size)

    if st.sidebar.button("Start Segmentation"):
        st.session_state.segmented_image = perform_custom_segmentation(image_resized, params)

    if st.sidebar.button("Change Colors"):
        if st.session_state.segmented_image is None:
            # Nothing to recolor until a segmentation has been produced.
            st.sidebar.warning("Run segmentation before changing colors.")
        else:
            randomize_colors()

    if st.session_state.segmented_image is not None:
        handle_color_picking()
        display_segmentation_results()
        calculate_and_display_label_percentages()
        download_image(st.session_state.segmented_image, 'segmented_image.png')
|
|
|
|
|
# Entry point: build the page when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|