Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
from src.ui.drawable_canvas import drawable_canvas | |
from src.ui.streamlit_ui import streamlit_ui | |
from src.segmentation import segment_everything | |
from src.utils import calculate_parameters, plot_distribution, calculate_pixel_length, plot_cumulative_frequency | |
from ultralytics import YOLO | |
import torch | |
import cv2 | |
# Cache the model and device | |
def load_model_and_initialize(): | |
model_path = "src/model/FastSAM-x.pt" | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model = YOLO(model_path) | |
return model, device | |
def main(): | |
"""Main application logic.""" | |
uploaded_image, input_size, iou_threshold, conf_threshold, better_quality, contour_thickness, real_world_length, max_det = streamlit_ui() | |
if uploaded_image is not None: | |
try: | |
canvas_result = drawable_canvas(uploaded_image, input_size) | |
pixel_length = None | |
if canvas_result.json_data is not None and "objects" in canvas_result.json_data: | |
if len(canvas_result.json_data["objects"]) > 0: | |
line_object = canvas_result.json_data["objects"][0] | |
start_point = [line_object['x1'], line_object['y1']] | |
end_point = [line_object['x2'], line_object['y2']] | |
# Get image dimensions for calculating the scaling factor | |
image_width, image_height = Image.open(uploaded_image).size | |
scale_factor = input_size / max(image_width, image_height) | |
# Calculate pixel length with the scaling factor | |
pixel_length = calculate_pixel_length(start_point, end_point) | |
st.write(f"Pixel length of the line: {pixel_length}") | |
else: | |
st.write("Please draw a line to set the scale or enter the real-world length.") | |
else: | |
st.write("Please draw a line to set the scale or enter the real-world length.") | |
if pixel_length is not None and real_world_length is not None: | |
scale_factor = real_world_length / pixel_length | |
else: | |
st.write("Scale factor could not be calculated. Make sure to draw a line and enter the real-world length.") | |
return | |
input_image = Image.open(uploaded_image) | |
# Load the model and device from cache | |
model, device = load_model_and_initialize() | |
segmented_image, annotations = segment_everything( | |
input_image, | |
model=model, | |
device=device, | |
input_size=input_size, | |
iou_threshold=iou_threshold, | |
conf_threshold=conf_threshold, | |
better_quality=better_quality, | |
contour_thickness=contour_thickness, | |
max_det=max_det | |
) | |
st.image(segmented_image, caption="Segmented Image", use_column_width=True) | |
# Calculate and display object parameters | |
df = calculate_parameters(annotations, scale_factor) | |
if not df.empty: | |
st.write("Summary of Object Parameters:") | |
st.dataframe(df) | |
csv = df.to_csv(index=False) | |
st.download_button( | |
label="Download data as CSV", | |
data=csv, | |
file_name='grain_parameters.csv', | |
mime='text/csv', | |
) | |
plot_cumulative_frequency(df) | |
filtered_columns = [col for col in df.columns.tolist() if col != 'Object'] | |
selected_parameter = st.selectbox("Select a parameter to see its distribution:", filtered_columns) | |
if selected_parameter: | |
plot_distribution(df, selected_parameter) | |
else: | |
st.write("No parameter selected for plotting.") | |
else: | |
st.write("No objects detected.") | |
except Exception as e: | |
st.error("An error occurred during processing. Please check the logs for details.") | |
else: | |
st.write("Please upload an image.") | |
if __name__ == "__main__": | |
main() | |