# GrainSight / app.py
# Last commit by fazzam: "Update app.py" (5210acf, verified)
import logging

import cv2
import streamlit as st
import torch
from PIL import Image
from ultralytics import YOLO

from src.segmentation import segment_everything
from src.ui.drawable_canvas import drawable_canvas
from src.ui.streamlit_ui import streamlit_ui
from src.utils import (
    calculate_parameters,
    calculate_pixel_length,
    plot_cumulative_frequency,
    plot_distribution,
)
# Cache the model and device as a shared resource. st.cache_resource keeps a
# single instance per distinct argument set and hands that same object to every
# caller, whereas st.cache_data pickles the return value and gives each caller
# its own copy -- wasteful (and not the intended use) for a model object.
@st.cache_resource
def load_model_and_initialize(model_path: str = "src/model/FastSAM-x.pt"):
    """Load the segmentation model and choose the compute device.

    Args:
        model_path: Path of the weights file passed to ``YOLO``. Defaults to
            the checkpoint bundled with the app, so existing zero-argument
            calls behave as before.

    Returns:
        A ``(model, device)`` tuple: the ``YOLO`` model built from
        ``model_path`` and a ``torch.device`` that is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``. The model
        is not moved to ``device`` here; the device is returned for the caller
        to use.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = YOLO(model_path)
    return model, device
def _read_pixel_length(canvas_result):
    """Return the pixel length of the first object drawn on the canvas, or None.

    Reads the endpoints ``x1``/``y1``/``x2``/``y2`` of
    ``canvas_result.json_data["objects"][0]`` and measures them with
    ``calculate_pixel_length``. Writes the measured length to the page, or a
    prompt to draw a line when nothing has been drawn yet.

    Returns:
        The value returned by ``calculate_pixel_length``, or ``None`` when the
        canvas holds no objects.

    Raises:
        KeyError: if the first object lacks the endpoint keys (presumably when
            it is not a line); left to the caller's error handler.
    """
    json_data = canvas_result.json_data
    if json_data is None or "objects" not in json_data or not json_data["objects"]:
        st.write("Please draw a line to set the scale or enter the real-world length.")
        return None

    # Only the first drawn object is used as the calibration line.
    line_object = json_data["objects"][0]
    start_point = [line_object['x1'], line_object['y1']]
    end_point = [line_object['x2'], line_object['y2']]
    pixel_length = calculate_pixel_length(start_point, end_point)
    st.write(f"Pixel length of the line: {pixel_length}")
    return pixel_length


def _show_results(df):
    """Render the parameter table, CSV download and plots for ``df``.

    Writes "No objects detected." instead when ``df`` is empty.
    """
    if df.empty:
        st.write("No objects detected.")
        return

    st.write("Summary of Object Parameters:")
    st.dataframe(df)
    st.download_button(
        label="Download data as CSV",
        data=df.to_csv(index=False),
        file_name='grain_parameters.csv',
        mime='text/csv',
    )
    plot_cumulative_frequency(df)

    # Every column except the 'Object' identifier can be plotted.
    filtered_columns = [col for col in df.columns if col != 'Object']
    selected_parameter = st.selectbox("Select a parameter to see its distribution:", filtered_columns)
    if selected_parameter:
        plot_distribution(df, selected_parameter)
    else:
        st.write("No parameter selected for plotting.")


def main():
    """Main application logic.

    Reads the user inputs from ``streamlit_ui``. Lets the user draw a
    calibration line on the uploaded image and derives a scale factor
    (real-world length per pixel) from it. Segments the image with the model
    from ``load_model_and_initialize`` and shows the per-object parameters with
    download and plotting options. Any exception raised while processing is
    logged with its traceback and reported to the user as a generic error.
    """
    (uploaded_image, input_size, iou_threshold, conf_threshold, better_quality,
     contour_thickness, real_world_length, max_det) = streamlit_ui()

    if uploaded_image is None:
        st.write("Please upload an image.")
        return

    try:
        canvas_result = drawable_canvas(uploaded_image, input_size)
        pixel_length = _read_pixel_length(canvas_result)

        # A zero-length line (a click without a drag) cannot define a scale; it
        # is treated like a missing line instead of dividing by zero.
        if not pixel_length or real_world_length is None:
            st.write("Scale factor could not be calculated. Make sure to draw a line and enter the real-world length.")
            return

        # NOTE(review): pixel_length is measured in canvas coordinates. An
        # earlier revision computed input_size / max(image width, height) here
        # but never used it. Confirm that the canvas and the segmentation
        # output share the same pixel space; if not, this ratio must be folded
        # into scale_factor.
        scale_factor = real_world_length / pixel_length

        input_image = Image.open(uploaded_image)
        model, device = load_model_and_initialize()
        segmented_image, annotations = segment_everything(
            input_image,
            model=model,
            device=device,
            input_size=input_size,
            iou_threshold=iou_threshold,
            conf_threshold=conf_threshold,
            better_quality=better_quality,
            contour_thickness=contour_thickness,
            max_det=max_det,
        )
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; confirm against the pinned version.
        st.image(segmented_image, caption="Segmented Image", use_column_width=True)

        _show_results(calculate_parameters(annotations, scale_factor))
    except Exception:
        # Top-level boundary of the page. Record the full traceback in the
        # server log, which the user-facing message points to, rather than
        # discarding it.
        logging.getLogger(__name__).exception("GrainSight image processing failed")
        st.error("An error occurred during processing. Please check the logs for details.")
# Entry point: start the app only when this file is executed directly, so that
# importing it as a module has no side effects.
if __name__ == "__main__":
    main()