"""Home page ("/") of the larvae-counting Dash app.

The page lets the user upload an image, pick a counting strategy and run
the matching counting routine from ``larvaecount.gradient`` on the image.
"""

import base64
from functools import partial
from io import BytesIO
from typing import Any, Dict, Optional, Tuple

import dash
import dash_bootstrap_components as dbc
import numpy as np
import plotly.express as px
from dash import html, dcc, callback, Input, Output, State
from dash.exceptions import PreventUpdate
from PIL import Image
from pillow_heif import register_heif_opener

from larvaecount.gradient import (
    component_thesh,
    component_filter_thresh,
    contour_thresh,
)
from larvaecount.ui.ui_utils import (
    get_cc_ui,
    get_cc_filter_ui,
    get_contour_ui,
    display_slider_value,
    get_results_container,
)

# Register the HEIF plugin with Pillow so Image.open can read HEIF uploads.
register_heif_opener()
dash.register_page(__name__, path="/")

UPLOAD_HEIGHT = "25vh"

# Strategy label shown in the dropdown -> factory building that strategy's
# parameter UI.
COUNT_FUNCS = {
    "Gradient CC": get_cc_ui,
    "Gradient CC w/ Filter": get_cc_filter_ui,
    "Gradient Contour w/ Filter": get_contour_ui,
}
DEFAULT_STRATEGY = "Gradient CC w/ Filter"


def _decode_image(image_b64: str) -> np.ndarray:
    """Decode a base64-encoded image file into a numpy array.

    Raises whatever ``base64.b64decode`` / ``PIL.Image.open`` raise when the
    payload is not valid base64 or not a readable image.
    """
    decoded_bytes = base64.b64decode(image_b64)
    with Image.open(BytesIO(decoded_bytes)) as pil_img:
        # np.array forces the pixel data to load before the image is closed.
        return np.array(pil_img)


def _load_stored_image(image_store: Optional[Dict]) -> Optional[np.ndarray]:
    """Return the image held in the ``img-data-store`` data, or None if empty."""
    if not image_store or "img" not in image_store:
        return None
    return _decode_image(image_store["img"])


def _missing_image_alert() -> dbc.Alert:
    """Build the message shown when counting is requested without an image."""
    return dbc.Alert(
        "Please upload an image before counting.",
        color="warning",
    )


def get_initial_upload_container() -> dcc.Upload:
    """Build the drag-and-drop upload area shown before any image is loaded."""
    return dcc.Upload(
        id="upload-data",
        children=dbc.Container(
            children=[
                html.Img(
                    src="assets/camera.png",
                    alt="camera-image",
                    className="h-50",
                ),
                html.H2("Drag and Drop or Select Image File"),
            ],
            class_name="w-100 d-flex flex-column justify-content-center align-items-center",
            style={
                "height": UPLOAD_HEIGHT,
            },
        ),
    )


def get_new_upload_container(
    image_b64: str,
    file_name: str,
) -> dbc.Container:
    """Build the post-upload view: file name, image figure and re-upload button.

    Args:
        image_b64: Base64-encoded bytes of the uploaded image file.
        file_name: Name of the uploaded file, shown as the heading.

    Raises:
        Any exception raised while decoding the payload or opening the image;
        the upload callback turns these into an error modal.
    """
    img = _decode_image(image_b64)
    image_fig = px.imshow(
        img,
        height=750,
    )

    return dbc.Container(
        children=[
            html.H3(
                children=file_name,
                className="p-2 text-start",
            ),
            dcc.Graph(
                figure=image_fig,
            ),
            dbc.Container(
                # Reuses the "upload-data" id so the same upload callback
                # fires when the user replaces the image.
                children=dcc.Upload(
                    children=dbc.Button(
                        children="Upload New Image",
                        color="secondary",
                    ),
                    id="upload-data",
                ),
                class_name="w-100 pb-4 d-flex flex-row justify-content-center align-items-center",
            ),
        ]
    )


layout = dbc.Container(
    children=[
        dbc.Container(
            children=dcc.Loading(
                children=get_initial_upload_container(),
                id="image-upload-container",
                color="black",
            ),
            class_name="m-0 p-0 border border-dark",
        ),
        # Holds the base64 payload of the current image as {"img": ...}.
        dcc.Store(
            id="img-data-store",
            storage_type="memory",
        ),
        dbc.Modal(
            children=[
                dbc.ModalHeader(
                    dbc.ModalTitle("Error Processing Image File")
                ),
                dbc.ModalBody(
                    children="",
                    id="upload-modal-content",
                ),
            ],
            is_open=False,
            id="upload-modal",
        ),
        html.H4("Select Counting Strategy", className="text-start mt-3"),
        dcc.Dropdown(
            options=list(COUNT_FUNCS),
            value=DEFAULT_STRATEGY,
            id="strat-picker",
            className="my-2 w-100",
        ),
        dbc.Container(
            id="count-ui-container",
            className="mt-1 mx-0 px-0",
        ),
        dcc.Loading(
            children=dbc.Container(
                id="count-res-container",
                className="mt-4 mx-0 px-0",
            ),
            type="default",
            color="black",
        ),
    ],
    class_name="text-center mt-3",
)


@callback(
    Output("image-upload-container", "children"),
    Output("img-data-store", "data"),
    Output("upload-modal-content", "children"),
    Output("upload-modal", "is_open"),
    Input("upload-data", "contents"),
    State("upload-data", "filename"),
    State("image-upload-container", "children"),
    State("img-data-store", "data"),
)
def on_image_upload(
    upload_image_data: str,
    upload_image_name: str,
    curr_upload_chidren: Any,
    curr_img_store_data: Dict,
) -> Tuple[Any, Dict, str, bool]:
    """Show a newly uploaded image and store it, or report why it failed.

    Returns a 4-tuple: upload-area children, image store data, error modal
    text, and whether the error modal is open. On failure the current
    children and store data are returned unchanged and the modal is opened
    with the exception message.

    Raises:
        PreventUpdate: when the upload component has no contents.
    """
    if not upload_image_data:
        raise PreventUpdate

    try:
        # Contents are "<header>,<base64 payload>"; split on the first comma
        # only so that a comma inside the header cannot break the unpacking.
        _content_type, content_string = upload_image_data.split(",", 1)
        next_children = get_new_upload_container(content_string, upload_image_name)
    except Exception as e:
        # UI boundary: any decode/parse failure is surfaced to the user in
        # the modal rather than as a failed callback.
        return (
            curr_upload_chidren,
            curr_img_store_data,
            str(e),
            True,
        )

    return (
        next_children,
        {"img": content_string},
        "",
        False,
    )


@callback(
    Output("count-ui-container", "children"),
    Input("strat-picker", "value"),
)
def on_select_strat(
    curr_strat: str,
) -> Optional[dbc.Container]:
    """Render the parameter UI of the selected strategy (None if unknown)."""
    ui_fun = COUNT_FUNCS.get(curr_strat)
    if ui_fun is None:
        return None
    return ui_fun()


@callback(
    Output("count-res-container", "children", allow_duplicate=True),
    Input("count-cc", "n_clicks"),
    State("select-cc-color-thresh", "value"),
    State("select-cc-avg-area", "value"),
    State("select-cc-max-eggs", "value"),
    State("img-data-store", "data"),
    prevent_initial_call=True,
)
def on_count_cc(
    n_clicks: int,
    color_thresh: int,
    avg_area: int,
    max_eggs: Optional[int],
    image_store: Dict,
) -> Any:
    """Run ``component_thesh`` on the stored image for the "Gradient CC" UI.

    Returns None when the button has not been clicked, a warning alert when
    no image has been uploaded, and otherwise whatever
    ``get_results_container`` builds from the counting results.
    """
    if not n_clicks:
        return None

    img = _load_stored_image(image_store)
    if img is None:
        return _missing_image_alert()

    color_thresh = int(color_thresh)
    avg_area = int(avg_area)
    # max_eggs is optional; it is only coerced when a truthy value is given.
    if max_eggs:
        max_eggs = int(max_eggs)

    results = component_thesh(
        img,
        color_thresh=color_thresh,
        avg_area=avg_area,
        max_eggs=max_eggs,
    )
    return get_results_container(results)


@callback(
    Output("count-res-container", "children", allow_duplicate=True),
    Input("count-cc-filter", "n_clicks"),
    State("select-cc-filter-color-thresh", "value"),
    State("select-cc-filter-avg-area", "value"),
    State("select-cc-filter-max-eggs", "value"),
    State("select-cc-kernel-width", "value"),
    State("select-cc-kernel-height", "value"),
    State("img-data-store", "data"),
    prevent_initial_call=True,
)
def on_count_cc_filter(
    n_clicks: int,
    color_thresh: int,
    avg_area: int,
    max_eggs: Optional[int],
    kernel_width: int,
    kernel_height: int,
    image_store: Dict,
) -> Any:
    """Run ``component_filter_thresh`` for the "Gradient CC w/ Filter" UI.

    Previously this function was also named ``on_count_cc`` and shadowed the
    plain "Gradient CC" callback at module level.

    Returns None when the button has not been clicked, a warning alert when
    no image has been uploaded, and otherwise whatever
    ``get_results_container`` builds from the counting results.
    """
    if not n_clicks:
        return None

    img = _load_stored_image(image_store)
    if img is None:
        return _missing_image_alert()

    color_thresh = int(color_thresh)
    avg_area = int(avg_area)
    kernel_width = int(kernel_width)
    kernel_height = int(kernel_height)
    # max_eggs is optional; it is only coerced when a truthy value is given.
    if max_eggs:
        max_eggs = int(max_eggs)

    results = component_filter_thresh(
        img,
        color_thresh=color_thresh,
        avg_area=avg_area,
        # "kernal_size" is the keyword spelling used by larvaecount.gradient.
        kernal_size=(kernel_width, kernel_height),
        max_eggs=max_eggs,
    )
    return get_results_container(results)


@callback(
    Output("count-res-container", "children", allow_duplicate=True),
    Input("count-contour", "n_clicks"),
    State("select-contour-color-thresh", "value"),
    State("select-contour-avg-area", "value"),
    State("select-contour-max-eggs", "value"),
    State("select-contour-width", "value"),
    State("select-contour-height", "value"),
    State("img-data-store", "data"),
    prevent_initial_call=True,
)
def on_count_contour(
    n_clicks: int,
    color_thresh: int,
    avg_area: int,
    max_eggs: Optional[int],
    kernel_width: int,
    kernel_height: int,
    image_store: Dict,
) -> Any:
    """Run ``contour_thresh`` for the "Gradient Contour w/ Filter" UI.

    Returns None when the button has not been clicked, a warning alert when
    no image has been uploaded, and otherwise whatever
    ``get_results_container`` builds from the counting results.
    """
    if not n_clicks:
        return None

    img = _load_stored_image(image_store)
    if img is None:
        return _missing_image_alert()

    color_thresh = int(color_thresh)
    avg_area = int(avg_area)
    kernel_width = int(kernel_width)
    kernel_height = int(kernel_height)
    # max_eggs is optional; it is only coerced when a truthy value is given.
    if max_eggs:
        max_eggs = int(max_eggs)

    results = contour_thresh(
        img,
        color_thresh=color_thresh,
        avg_area=avg_area,
        # "kernal_size" is the keyword spelling used by larvaecount.gradient.
        kernal_size=(kernel_width, kernel_height),
        max_eggs=max_eggs,
    )
    return get_results_container(results)


# Each strategy UI has its own colour-threshold slider; wire every slider to
# its label through display_slider_value with a fixed first argument.
callback(
    Output("display-cc-color-thresh", "children"),
    Input("select-cc-color-thresh", "value"),
)(partial(display_slider_value, "Color Threshold"))

callback(
    Output("display-cc-filter-color-thresh", "children"),
    Input("select-cc-filter-color-thresh", "value"),
)(partial(display_slider_value, "Color Threshold"))

callback(
    Output("display-contour-color-thresh", "children"),
    Input("select-contour-color-thresh", "value"),
)(partial(display_slider_value, "Color Threshold"))