Spaces:
Running
Running
| """Dash callbacks for AION Search.""" | |
| import json | |
| import time | |
| import logging | |
| import traceback | |
| import uuid | |
| import pandas as pd | |
| import dash | |
| from dash import Input, Output, State, callback_context, html, dcc | |
| import dash_bootstrap_components as dbc | |
| import src.config as config | |
| from src.config import ( | |
| DEFAULT_DISPLAY_COUNT, | |
| LOAD_MORE_COUNT, | |
| IMAGE_HEIGHT, | |
| IMAGE_WIDTH, | |
| ZILLIZ_PRIMARY_KEY, | |
| ) | |
| from src.components import create_vector_input_row | |
| from src.services import SearchService | |
| logger = logging.getLogger(__name__) | |
| def register_callbacks(app, search_service: SearchService): | |
| """Register all Dash callbacks with the app. | |
| Args: | |
| app: Dash app instance | |
| search_service: SearchService instance for performing searches | |
| """ | |
def update_galaxy_count(_):
    """Return the text for the galaxy-count display.

    The single argument is the triggering value and is ignored.

    Returns:
        A thousands-separated count string when ``search_service`` is truthy
        and ``config.TOTAL_GALAXIES`` is positive; otherwise ``"loading..."``.
    """
    count_is_known = bool(search_service) and config.TOTAL_GALAXIES > 0
    if not count_is_known:
        return "loading..."
    return f"{config.TOTAL_GALAXIES:,} GALAXIES FROM LEGACY SURVEY DR10"
def toggle_search_mode(n_clicks, advanced_style, basic_text):
    """Toggle between the basic and advanced search interfaces.

    Args:
        n_clicks: Click count of the toggle control; only acts as a trigger.
        advanced_style: Current style dict of the advanced interface. ``None``
            (style never set) or any dict whose ``display`` is not ``"block"``
            is treated as "advanced is hidden".
        basic_text: Current text of the basic search input.

    Returns:
        A 4-tuple ``(style_a, style_b, arrow_class, advanced_text)``. When the
        advanced interface is currently shown the styles are
        ``({"display": "block"}, {"display": "none"})`` with a chevron-down
        icon and the text left untouched; otherwise the styles are swapped,
        the icon is chevron-up, and ``basic_text`` (or ``""``) is copied over.
    """
    # A component whose style prop was never set reports None; guard against
    # calling .get on it.
    is_advanced_shown = (advanced_style or {}).get("display") == "block"

    if is_advanced_shown:
        # Switch to basic mode.
        return {"display": "block"}, {"display": "none"}, "fas fa-chevron-down me-2", dash.no_update

    # Switch to advanced mode and carry the basic query text across.
    return {"display": "none"}, {"display": "block"}, "fas fa-chevron-up me-2", basic_text or ""
def delete_vector_input(n_clicks_list, current_children):
    """Remove the vector input row whose delete button was clicked.

    Args:
        n_clicks_list: Click counts of all delete buttons.
        current_children: Current row components, either as serialized dicts
            (``{"props": {"id": {...}}}``) or component objects with an ``id``
            attribute. ``None`` is treated as an empty list.

    Returns:
        The remaining rows, or a single fresh row (index 0) if the deletion
        would leave none. ``dash.no_update`` when nothing was really clicked.
    """
    if not n_clicks_list or not any(n_clicks_list):
        return dash.no_update

    ctx = callback_context
    if not ctx.triggered:
        return dash.no_update

    trigger = ctx.triggered[0]
    # Ignore triggers that carry no click (value None or 0).
    if trigger["value"] is None or trigger["value"] == 0:
        return dash.no_update

    # prop_id has the form "<json id>.<property>"; split on the LAST dot so a
    # dot inside the JSON id cannot truncate it.
    component_id = trigger["prop_id"].rsplit(".", 1)[0]
    index_to_delete = json.loads(component_id)["index"]
    logger.info("Delete button clicked for index: %s", index_to_delete)

    def _is_target_row(child):
        """Return True when *child* is the vector-row with the clicked index."""
        if isinstance(child, dict):
            child_id = (child.get("props") or {}).get("id")
        else:
            child_id = getattr(child, "id", None)
        return (
            isinstance(child_id, dict)
            and child_id.get("type") == "vector-row"
            and child_id.get("index") == index_to_delete
        )

    remaining = [child for child in (current_children or []) if not _is_target_row(child)]

    # Ensure at least one input remains.
    if not remaining:
        remaining = [create_vector_input_row(0)]
    return remaining
def add_vector_input(n_clicks, current_children, count):
    """Append a new vector input row.

    Args:
        n_clicks: Click count of the "add" control; falsy means no click yet.
        current_children: Existing row components; ``None`` is treated as an
            empty list.
        count: Index to give the new row; ``None`` is treated as 0.

    Returns:
        ``(children, count + 1)`` with the new row appended, or a pair of
        ``dash.no_update`` when there was no click.
    """
    if not n_clicks:
        return dash.no_update, dash.no_update

    next_index = count or 0
    # Copy rather than mutate the incoming list, and tolerate None children.
    children = list(current_children or [])
    children.append(create_vector_input_row(next_index))
    return children, next_index + 1
def toggle_main_query_type(query_type):
    """Pick which main-query input is visible for the given query type.

    Returns:
        A pair of style dicts. For ``"image"`` the first is hidden and the
        second shown; for any other value the first is shown and the second
        hidden.
    """
    shown = {"display": "block"}
    hidden = {"display": "none"}
    return (hidden, shown) if query_type == "image" else (shown, hidden)
def toggle_query_type_inputs(query_types):
    """Compute per-row visibility of the text and image inputs.

    Args:
        query_types: One query-type string per row; ``"text"`` selects the
            text input, anything else selects the image input.

    Returns:
        ``(text_styles, image_styles)``: two lists of style dicts, aligned
        with ``query_types``.
    """
    is_text_row = [kind == "text" for kind in query_types]
    text_styles = [{"display": "block" if flag else "none"} for flag in is_text_row]
    image_styles = [{"display": "none" if flag else "block"} for flag in is_text_row]
    return text_styles, image_styles
def trigger_search_from_examples(click1, click2, click3, click5, click6, click7, click8, current_clicks):
    """Start a search for the example query whose button was clicked.

    The individual click arguments only act as triggers; the clicked button is
    identified through ``callback_context``.

    Returns:
        ``(current_clicks + 1, query, query)`` for a recognised example
        button, otherwise three ``dash.no_update`` values.
    """
    unchanged = (dash.no_update, dash.no_update, dash.no_update)

    fired = callback_context.triggered
    if not fired:
        return unchanged

    queries_by_button = {
        "example-1": "Two edge-on galaxies",
        "example-2": "A peculiar interacting galaxy system featuring plenty of tidal tails and a disturbed morphology",
        "example-3": "galaxy with stream",
        "example-5": "A violent merger in progress with visible tidal features",
        "example-6": "Low surface brightness",
        "example-7": "A face-on spiral with ring-like circular structure around a core",
        "example-8": "a bursty, star forming galaxy"
    }

    source_id = fired[0]["prop_id"].split(".")[0]
    chosen_query = queries_by_button.get(source_id, "")
    if not chosen_query:
        return unchanged
    return (current_clicks or 0) + 1, chosen_query, chosen_query
def perform_search(n_clicks_basic, n_submit_basic, n_clicks_advanced, n_submit_advanced, n_submit_vector_texts,
                   query_basic, query_advanced,
                   rmag_range, advanced_style, main_operation, main_query_type,
                   main_ra, main_dec,
                   additional_query_types, additional_text_values, additional_ra_values,
                   additional_dec_values, additional_operations):
    """Run a basic text search, or delegate to the advanced search.

    The click/submit arguments only act as triggers. Advanced mode is chosen
    when the trigger is the advanced button/input or when ``advanced_style``
    has ``display == "block"``; in that case the call is forwarded to
    ``perform_advanced_search_logic``.

    Args:
        query_basic: Text of the basic search input.
        query_advanced: Text of the advanced main query input.
        rmag_range: ``[min, max]`` r-magnitude range, or a falsy value for no
            filter.
        advanced_style: Style dict of the advanced interface (may be ``None``).
        main_operation, main_query_type, main_ra, main_dec: Main advanced
            query settings, forwarded unchanged.
        additional_*: Per-row values of the additional advanced queries,
            forwarded unchanged.

    Returns:
        A 6-tuple. On success
        ``("", results_container, search_data, False, False, search_params)``;
        on an empty query or a failure ``("", alert, None, True, True, None)``;
        six ``dash.no_update`` values when nothing triggered the callback.
    """
    ctx = callback_context
    if not ctx.triggered:
        return (dash.no_update,) * 6

    trigger_id = ctx.triggered[0]["prop_id"].split(".")[0]

    # advanced_style is None when the style prop was never set; treat that as
    # "advanced interface hidden" instead of crashing on .get.
    is_advanced_mode = (
        trigger_id in ("search-button-advanced", "search-input-advanced")
        or (advanced_style or {}).get("display") == "block"
    )
    if is_advanced_mode:
        return perform_advanced_search_logic(
            query_advanced, rmag_range, main_operation, main_query_type,
            main_ra, main_dec,
            additional_query_types, additional_text_values,
            additional_ra_values, additional_dec_values,
            additional_operations
        )

    # Basic search from here on.
    query = query_basic
    if not query or not query.strip():
        return "", dbc.Alert("Please enter a search query", color="warning"), None, True, True, None

    # Unique id tying the query log entry to later click events.
    request_id = uuid.uuid4().hex

    try:
        from src.utils import build_query_xml, log_query_to_csv

        rmag_min, rmag_max = rmag_range if rmag_range else (None, None)

        start_time = time.time()
        df = search_service.search_text(query, rmag_min=rmag_min, rmag_max=rmag_max)
        search_time = time.time() - start_time
        logger.debug("Basic search finished in %.3fs", search_time)

        # Log the query.
        query_xml = build_query_xml(
            text_queries=[query],
            text_weights=[1.0],
            rmag_min=rmag_min,
            rmag_max=rmag_max
        )
        log_query_to_csv(query_xml, request_id=request_id)

        total_results = len(df)
        shown_results = min(DEFAULT_DISPLAY_COUNT, total_results)

        # Only the first DEFAULT_DISPLAY_COUNT results are rendered up front.
        grid_items = build_galaxy_grid(df.head(DEFAULT_DISPLAY_COUNT))
        search_data = prepare_search_data(df, query, request_id=request_id)

        load_more_button = (
            create_load_more_button(total_results, DEFAULT_DISPLAY_COUNT)
            if total_results > DEFAULT_DISPLAY_COUNT else None
        )

        # Mention the magnitude filter only when it differs from 13.0-20.0.
        filter_desc = ""
        if rmag_min is not None and rmag_max is not None and (rmag_min != 13.0 or rmag_max != 20.0):
            filter_desc = f" + r-mag: [{rmag_min:.1f}, {rmag_max:.1f}]"

        results_container = html.Div([
            html.Div([
                html.P(f"Top {total_results} matching galaxies (showing {shown_results})",
                       className="results-header mb-2 text-center d-inline-block me-2"),
                dbc.Button(
                    [html.I(className="fas fa-link me-1"), "Copy link"],
                    id="copy-results-link",
                    color="link",
                    size="sm",
                    n_clicks=0,
                    className="info-button"
                ),
                html.Span(id="copy-results-feedback", style={"marginLeft": "8px", "color": "#28a745", "fontSize": "0.8rem"})
            ], className="text-center"),
            html.P(f"'{query}'{filter_desc}",
                   className="text-center mb-3",
                   style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"}),
            dbc.Row(grid_items, justify="center", id="search-results-grid"),
            load_more_button
        ])

        # Parameters kept for URL generation.
        search_params = {
            "text_queries": [query],
            "text_weights": [1.0],
            "image_queries": [],
            "image_weights": [],
            "rmag_min": rmag_min,
            "rmag_max": rmag_max
        }

        return "", results_container, search_data, False, False, search_params

    except Exception as e:
        error_msg = dbc.Alert(f"Search failed: {str(e)}", color="danger")
        logger.error(f"Search error: {e}")
        logger.error(f"Full traceback:\n{traceback.format_exc()}")

        # Recording the failure is best-effort: it must never mask the
        # original error, but a failure here is still worth a log line.
        try:
            from src.utils import build_query_xml, log_query_to_csv
            query_xml = build_query_xml(
                text_queries=[query],
                text_weights=[1.0],
                rmag_min=rmag_range[0] if rmag_range else None,
                rmag_max=rmag_range[1] if rmag_range else None
            )
            log_query_to_csv(
                query_xml,
                request_id=request_id,
                error_occurred=True,
                error_message=str(e),
                error_type=type(e).__name__
            )
        except Exception:
            logger.exception("Could not record failed search %s", request_id)

        return "", error_msg, None, True, True, None
def toggle_modal(image_clicks, close_click, is_open, search_data):
    """Open the galaxy detail modal for a clicked image, or close it.

    Args:
        image_clicks: Click counts of the galaxy images (trigger only).
        close_click: Click count of the close button (trigger only).
        is_open: Current modal state (unused).
        search_data: Stored search results; must provide at least an ``"ra"``
            list and whatever ``extract_galaxy_info`` reads.

    Returns:
        A 6-tuple ``(is_open, title, image, description, galaxy_data,
        copy_feedback)``. Anything other than a valid image click yields the
        closed state ``(False, "", "", "", None, "")``.
    """
    closed = (False, "", "", "", None, "")

    ctx = callback_context
    if not ctx.triggered:
        return closed

    trigger = ctx.triggered[0]
    triggered_prop = trigger["prop_id"]
    if triggered_prop == "close-modal.n_clicks" or not search_data:
        return closed

    # Ignore triggers that carry no click (value None or 0).
    triggered_value = trigger["value"]
    if triggered_value is None or triggered_value == 0:
        return closed

    if "galaxy-image" not in triggered_prop:
        return closed

    try:
        clicked_idx = json.loads(triggered_prop.split(".n_clicks")[0])["index"]
        # Reject out-of-range indices, including negatives, which would
        # otherwise silently index from the end of the result lists.
        if not 0 <= clicked_idx < len(search_data["ra"]):
            return closed

        galaxy_info = extract_galaxy_info(search_data, clicked_idx)
        image_element, description_element = build_modal_content(galaxy_info)
        galaxy_data = {
            ZILLIZ_PRIMARY_KEY: galaxy_info[ZILLIZ_PRIMARY_KEY],
            "ra": galaxy_info["ra"],
            "dec": galaxy_info["dec"],
            "distance": galaxy_info["distance"],
            "r_mag": galaxy_info["r_mag"]
        }
        title = f"Galaxy at RA={galaxy_info['ra']:.6f}, Dec={galaxy_info['dec']:.6f}"
    except Exception:
        # Keep the modal closed on failure, but leave a trace in the log.
        logger.exception("Failed to build galaxy modal for %s", triggered_prop)
        return closed

    # Click logging is best-effort; it must not stop the modal from opening.
    try:
        from src.utils import log_click_event
        log_click_event(
            request_id=search_data.get("request_id"),
            rank=clicked_idx,  # 0-indexed rank
            primary_key=galaxy_info[ZILLIZ_PRIMARY_KEY],
            ra=galaxy_info["ra"],
            dec=galaxy_info["dec"],
            r_mag=galaxy_info["r_mag"],
            distance=galaxy_info["distance"]
        )
    except Exception:
        logger.exception("Failed to log click event for result %s", clicked_idx)

    return (
        True,
        title,
        image_element,
        description_element,
        galaxy_data,
        ""  # Clear copy feedback when opening a new galaxy
    )
def update_legacy_survey_link(galaxy_data):
    """Build the Legacy Survey viewer URL for the stored galaxy.

    Returns:
        The viewer URL containing the galaxy's ``ra`` and ``dec``, or ``"#"``
        when ``galaxy_data`` is empty or lacks either coordinate key.
    """
    has_coordinates = bool(galaxy_data) and "ra" in galaxy_data and "dec" in galaxy_data
    if not has_coordinates:
        return "#"

    right_ascension = galaxy_data["ra"]
    declination = galaxy_data["dec"]
    return f"https://www.legacysurvey.org/viewer?ra={right_ascension}&dec={declination}&layer=ls-dr10&zoom=16"
def expand_galaxy_from_url(search_data, pending_galaxy_id):
    """Open the galaxy modal for a galaxy id carried over from URL state.

    The id is looked up among ``search_data[ZILLIZ_PRIMARY_KEY]``, first as
    is and then with any ``b'...'`` wrapper stripped from the stored keys.

    Returns:
        A 7-tuple. When the galaxy is found: ``(True, title, image,
        description, galaxy_data, None, "")``. When it is not found, all
        ``dash.no_update`` except the sixth element, which is ``None`` to
        clear the pending id. When there is no pending id or no search data,
        seven ``dash.no_update`` values.
    """
    keep = dash.no_update
    if not (pending_galaxy_id and search_data):
        return (keep,) * 7

    def strip_bytes_repr(value):
        """Turn the string "b'abc'" into "abc"; leave other values alone."""
        if isinstance(value, str) and value.startswith("b'") and value.endswith("'"):
            return value[2:-1]
        return value

    # Leaves the modal untouched but clears the pending galaxy id.
    not_found = (keep, keep, keep, keep, keep, None, keep)

    primary_keys = search_data.get(ZILLIZ_PRIMARY_KEY, [])
    wanted = strip_bytes_repr(pending_galaxy_id)

    try:
        if wanted in primary_keys:
            # Exact match takes precedence.
            position = primary_keys.index(wanted)
        else:
            normalised_keys = [strip_bytes_repr(key) for key in primary_keys]
            if wanted not in normalised_keys:
                return not_found
            position = normalised_keys.index(wanted)

        info = extract_galaxy_info(search_data, position)
        image_element, description_element = build_modal_content(info)
        galaxy_data = {
            ZILLIZ_PRIMARY_KEY: info[ZILLIZ_PRIMARY_KEY],
            "ra": info["ra"],
            "dec": info["dec"],
            "distance": info["distance"],
            "r_mag": info["r_mag"]
        }
        return (
            True,  # Open modal
            f"Galaxy at RA={info['ra']:.6f}, Dec={info['dec']:.6f}",
            image_element,
            description_element,
            galaxy_data,
            None,  # Clear pending galaxy
            ""  # Clear copy feedback
        )
    except (ValueError, IndexError):
        return not_found
def toggle_info_modal(info_click, close_click, is_open):
    """Open or close the info modal depending on which button fired.

    Returns:
        ``True`` for ``info-button``, ``False`` for ``close-info-modal``, and
        the unchanged ``is_open`` for any other (or no) trigger.
    """
    fired = callback_context.triggered
    if not fired:
        return is_open

    source_id = fired[0]["prop_id"].split(".")[0]
    state_by_button = {"info-button": True, "close-info-modal": False}
    return state_by_button.get(source_id, is_open)
def load_more_galaxies(n_clicks, search_data):
    """Re-render the results with the next batch of galaxies included.

    The whole grid (already shown cards plus up to ``LOAD_MORE_COUNT`` new
    ones) is rebuilt, and ``search_data["loaded_count"]`` is advanced.

    Returns:
        ``(results_container, search_data)``, or two ``dash.no_update``
        values when there was no click or no usable search data.
    """
    can_extend = n_clicks and search_data and "loaded_count" in search_data
    if not can_extend:
        return dash.no_update, dash.no_update

    total_count = len(search_data["ra"])
    next_count = min(search_data["loaded_count"] + LOAD_MORE_COUNT, total_count)

    # Rebuild every card from the first result up to the new count.
    cards = [
        build_galaxy_card(extract_galaxy_info(search_data, position), position)
        for position in range(next_count)
    ]
    search_data["loaded_count"] = next_count

    # The button is only kept while there is something left to load.
    more_button = None
    if next_count < total_count:
        more_button = create_load_more_button(total_count, next_count)

    header = html.Div([
        html.P(f"Top {total_count} matching galaxies (showing {next_count})",
               className="results-header mb-2 text-center d-inline-block me-2"),
        dbc.Button(
            [html.I(className="fas fa-link me-1"), "Copy link"],
            id="copy-results-link",
            color="link",
            size="sm",
            n_clicks=0,
            className="info-button"
        ),
        html.Span(id="copy-results-feedback", style={"marginLeft": "8px", "color": "#28a745", "fontSize": "0.8rem"})
    ], className="text-center")

    query_line = html.P(f"'{search_data['query']}'",
                        className="text-center mb-3",
                        style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"})

    results_container = html.Div([
        header,
        query_line,
        dbc.Row(cards, justify="center", id="search-results-grid"),
        more_button
    ])
    return results_container, search_data
def add_galaxy_to_advanced_search(n_clicks, galaxy_data, current_children, count):
    """Add the current galaxy's RA/Dec as a new image query in advanced search.

    Args:
        n_clicks: Click count of the triggering button; falsy means no click.
        galaxy_data: Stored data of the open galaxy; needs ``ra`` and ``dec``.
        current_children: Existing vector input rows; ``None`` is treated as
            an empty list.
        count: Index to give the new row; ``None`` is treated as 0.

    Returns:
        A 5-tuple ``(advanced_style, basic_style, children, count + 1,
        modal_open)`` that shows the advanced interface, hides the basic
        search bar, appends the new row and closes the modal; or five
        ``dash.no_update`` values when there is nothing to add.
    """
    unchanged = (dash.no_update,) * 5
    if not n_clicks or not galaxy_data:
        return unchanged

    ra = galaxy_data.get('ra')
    dec = galaxy_data.get('dec')
    if ra is None or dec is None:
        return unchanged

    next_index = count or 0
    new_row = create_vector_input_row(
        index=next_index,
        query_type="image",
        ra=ra,
        dec=dec,
        fov=0.025
    )

    # Copy rather than mutate the incoming list, and tolerate None children.
    children = list(current_children or [])
    children.append(new_row)

    return (
        {"display": "block"},  # Show advanced interface
        {"display": "none"},   # Hide basic search bar
        children,              # Updated children with new query
        next_index + 1,        # Incremented count
        False                  # Close modal
    )
def download_csv(n_clicks_basic, n_clicks_advanced, search_data):
    """Offer the stored search results as a CSV download.

    Returns:
        A dict with ``content`` (the CSV text) and ``filename``, or
        ``dash.no_update`` when no download button was clicked or there is no
        search data.
    """
    was_clicked = n_clicks_basic or n_clicks_advanced
    if not (was_clicked and search_data):
        return dash.no_update

    # Column order here is the column order of the CSV; the stored "distance"
    # is exported under the name "query_similarity".
    columns = {
        ZILLIZ_PRIMARY_KEY: search_data[ZILLIZ_PRIMARY_KEY],
        'ra': search_data['ra'],
        'dec': search_data['dec'],
        'r_mag': search_data['r_mag'],
        'query_similarity': search_data['distance'],
        'cutout_url': search_data['cutout_url']
    }
    csv_text = pd.DataFrame(columns).to_csv(index=False)
    return {"content": csv_text, "filename": "galaxy_search_results.csv"}
def perform_advanced_search_logic(query, rmag_range, main_operation, main_query_type,
                                  main_ra, main_dec,
                                  additional_query_types, additional_text_values,
                                  additional_ra_values, additional_dec_values,
                                  additional_operations):
    """Run an advanced search combining the main query with additional ones.

    Each query is either text or an image position (RA/Dec) and carries a
    weight derived from its operation string (``"+"`` -> 1.0, ``"-"`` -> -1.0,
    anything else parsed as a float). Blank text queries and image queries
    with a missing coordinate are skipped.

    Args:
        query: Text of the main query (used when ``main_query_type`` is
            ``"text"``).
        rmag_range: ``[min, max]`` r-magnitude range, or falsy for no filter.
        main_operation: Operation string of the main query.
        main_query_type: ``"text"``; any other value is treated as image.
        main_ra, main_dec: Coordinates of the main image query.
        additional_query_types, additional_text_values, additional_ra_values,
        additional_dec_values, additional_operations: Parallel per-row lists
            describing the additional queries.

    Returns:
        A 6-tuple. On success
        ``("", results_container, search_data, False, False, search_params)``;
        on invalid input, no usable query, or a failure
        ``("", alert, None, True, True, None)``.
    """
    # Unique id tying the query log entry to later click events.
    request_id = uuid.uuid4().hex

    def operation_to_weight(op_str):
        """Convert an operation string to a float weight."""
        if op_str == "+":
            return 1.0
        if op_str == "-":
            return -1.0
        return float(op_str)

    def weight_to_display(weight):
        """Convert a weight back to its display string."""
        if weight == 1.0:
            return "+"
        if weight == -1.0:
            return "-"
        # "g" formatting prints whole numbers without ".0" (2.0 -> "2") while
        # keeping fractional weights intact (0.5 -> "0.5" rather than "0").
        magnitude = f"{weight:g}"
        return f"+{magnitude}" if weight > 0 else magnitude

    text_queries = []
    text_weights = []
    image_queries = []
    image_weights = []

    def add_query(query_type, text_value, ra, dec, operation):
        """Validate one query and append it to the text or image lists."""
        weight = operation_to_weight(operation)
        if query_type == "text":
            if text_value and text_value.strip():
                text_queries.append(text_value.strip())
                text_weights.append(weight)
        elif ra is not None and dec is not None:
            image_queries.append({
                'ra': float(ra),
                'dec': float(dec),
                'fov': 0.025
            })
            image_weights.append(weight)

    # A malformed weight or coordinate should produce a warning for the user,
    # not an unhandled exception in the callback.
    try:
        add_query(main_query_type, query, main_ra, main_dec, main_operation)
        additional_rows = zip(
            additional_query_types or [],
            additional_text_values or [],
            additional_ra_values or [],
            additional_dec_values or [],
            additional_operations or [],
        )
        for row_type, row_text, row_ra, row_dec, row_operation in additional_rows:
            add_query(row_type, row_text, row_ra, row_dec, row_operation)
    except (TypeError, ValueError) as parse_error:
        logger.warning("Invalid advanced search input: %s", parse_error)
        return "", dbc.Alert("Invalid weight or coordinate value in query", color="warning"), None, True, True, None

    # At least one usable query is required.
    if not text_queries and not image_queries:
        return "", dbc.Alert("Please enter at least one text or image query", color="warning"), None, True, True, None

    try:
        from src.utils import build_query_xml, log_query_to_csv

        rmag_min, rmag_max = rmag_range if rmag_range else (None, None)

        start_time = time.time()
        df = search_service.search_advanced(
            text_queries=text_queries or None,
            text_weights=text_weights or None,
            image_queries=image_queries or None,
            image_weights=image_weights or None,
            rmag_min=rmag_min,
            rmag_max=rmag_max
        )
        search_time = time.time() - start_time
        logger.debug("Advanced search finished in %.3fs", search_time)

        # Log the query.
        query_xml = build_query_xml(
            text_queries=text_queries or None,
            text_weights=text_weights or None,
            image_queries=image_queries or None,
            image_weights=image_weights or None,
            rmag_min=rmag_min,
            rmag_max=rmag_max
        )
        log_query_to_csv(query_xml, request_id=request_id)

        total_results = len(df)
        shown_results = min(DEFAULT_DISPLAY_COUNT, total_results)
        grid_items = build_galaxy_grid(df.head(DEFAULT_DISPLAY_COUNT))

        # Build the plain-text description (stored) and the rich display
        # (rendered, with thumbnails for image queries) side by side.
        query_desc_parts = []
        query_display_parts = []
        for text_query, weight in zip(text_queries, text_weights):
            op_display = weight_to_display(weight)
            query_desc_parts.append(f"{op_display} text:'{text_query}'")
            query_display_parts.append(
                html.Span(f"{op_display} text:'{text_query}' ", style={"margin-right": "8px"})
            )

        if image_queries:
            from src.utils import cutout_url
        for img_query, weight in zip(image_queries, image_weights):
            op_display = weight_to_display(weight)
            query_desc_parts.append(
                f"{op_display} image:(RA={img_query['ra']:.2f}, Dec={img_query['dec']:.2f})"
            )
            thumbnail_url = cutout_url(
                img_query['ra'],
                img_query['dec'],
                fov=img_query.get('fov', 0.025),
                size=64
            )
            query_display_parts.append(html.Span([
                f"{op_display} ",
                html.Img(
                    src=thumbnail_url,
                    style={
                        "width": "128px",
                        "height": "128px",
                        "vertical-align": "middle",
                        "margin": "0 4px",
                        "border-radius": "4px",
                        "border": "1px solid rgba(255, 255, 255, 0.2)"
                    }
                )
            ], style={"margin-right": "8px", "display": "inline-block"}))
        query_description = " ".join(query_desc_parts)

        # Mention the magnitude filter only when it differs from 13.0-20.0.
        filter_desc = ""
        if rmag_min is not None and rmag_max is not None and (rmag_min != 13.0 or rmag_max != 20.0):
            filter_desc = f" + r-mag: [{rmag_min:.1f}, {rmag_max:.1f}]"

        # Data kept in the store.
        search_data = prepare_search_data(df, query_description, is_vector_search=True, request_id=request_id)
        search_data["text_queries"] = text_queries
        search_data["text_weights"] = text_weights
        search_data["image_queries"] = image_queries
        search_data["image_weights"] = image_weights

        load_more_button = (
            create_load_more_button(total_results, DEFAULT_DISPLAY_COUNT)
            if total_results > DEFAULT_DISPLAY_COUNT else None
        )

        results_container = html.Div([
            html.Div([
                html.P(f"Top {total_results} matching galaxies (showing {shown_results})",
                       className="results-header mb-2 text-center d-inline-block me-2"),
                dbc.Button(
                    [html.I(className="fas fa-link me-1"), "Copy link"],
                    id="copy-results-link",
                    color="link",
                    size="sm",
                    n_clicks=0,
                    className="info-button"
                ),
                html.Span(id="copy-results-feedback", style={"marginLeft": "8px", "color": "#28a745", "fontSize": "0.8rem"})
            ], className="text-center"),
            html.P(
                query_display_parts + ([f"{filter_desc}"] if filter_desc else []),
                className="text-center mb-3",
                style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"}
            ),
            dbc.Row(grid_items, justify="center", id="search-results-grid"),
            load_more_button
        ])

        # Parameters kept for URL generation (image queries as tuples).
        search_params = {
            "text_queries": text_queries,
            "text_weights": text_weights,
            "image_queries": [(img['ra'], img['dec'], img.get('fov', 0.025)) for img in image_queries],
            "image_weights": image_weights,
            "rmag_min": rmag_min,
            "rmag_max": rmag_max
        }

        return "", results_container, search_data, False, False, search_params

    except Exception as e:
        error_msg = dbc.Alert(f"Advanced search failed: {str(e)}", color="danger")
        logger.error(f"Advanced search error: {e}")
        logger.error(f"Full traceback:\n{traceback.format_exc()}")

        # Recording the failure is best-effort: it must never mask the
        # original error, but a failure here is still worth a log line.
        try:
            from src.utils import build_query_xml, log_query_to_csv
            query_xml = build_query_xml(
                text_queries=text_queries or None,
                text_weights=text_weights or None,
                image_queries=image_queries or None,
                image_weights=image_weights or None,
                rmag_min=rmag_range[0] if rmag_range else None,
                rmag_max=rmag_range[1] if rmag_range else None
            )
            log_query_to_csv(
                query_xml,
                request_id=request_id,
                error_occurred=True,
                error_message=str(e),
                error_type=type(e).__name__
            )
        except Exception:
            logger.exception("Could not record failed advanced search %s", request_id)

        return "", error_msg, None, True, True, None
| # ===== URL State Management Callbacks ===== | |
def restore_state_from_url(search_string):
    """Parse the URL search string, populate the UI and trigger a search.

    Returns:
        A 13-tuple, in this order: search-input-advanced,
        main-vector-operation, main-query-type, main-vector-ra,
        main-vector-dec, vector-inputs, rmag-slider, advanced-search-interface
        style, basic-search-bar style, vector-arrow, url-search-trigger,
        pending-expand-galaxy, search-input.

        * No queries in the URL: everything is ``dash.no_update`` except
          pending-expand-galaxy, which is ``None``.
        * A single text query with weight 1.0 and no image queries: basic
          mode is shown and only the basic input and slider are filled.
        * Otherwise: advanced mode is shown; the first query becomes the main
          query and the rest become additional vector rows.
    """
    from src.url_state import parse_url_search_param
    from src.components import create_vector_input_row

    keep = dash.no_update
    state = parse_url_search_param(search_string)

    # Nothing to restore: leave the UI alone and do not trigger a search.
    if not (state.get('text_queries') or state.get('image_queries')):
        return (keep,) * 11 + (None, keep)

    text_queries = state.get('text_queries', [])
    text_weights = state.get('text_weights', [])
    image_queries = state.get('image_queries', [])
    image_weights = state.get('image_weights', [])
    slider_value = [state.get('rmag_min', 13.0), state.get('rmag_max', 20.0)]
    expand_galaxy = state.get('expand_galaxy')

    # A lone text query with weight +1 is a basic search.
    is_basic_search = (
        len(text_queries) == 1 and
        len(text_weights) == 1 and
        text_weights[0] == 1.0 and
        len(image_queries) == 0
    )
    if is_basic_search:
        return (
            keep,                        # search-input-advanced
            keep,                        # main-vector-operation
            keep,                        # main-query-type
            keep,                        # main-vector-ra
            keep,                        # main-vector-dec
            keep,                        # vector-inputs
            slider_value,                # rmag-slider
            {"display": "none"},         # advanced-search-interface (hide)
            {"display": "block"},        # basic-search-bar (show)
            "fas fa-chevron-down me-2",  # vector-arrow (basic mode)
            1,                           # url-search-trigger
            expand_galaxy,               # pending-expand-galaxy
            text_queries[0]              # search-input
        )

    # Advanced search: flatten text queries first, then image queries.
    entries = [
        ('text', text, None, None, weight)
        for text, weight in zip(text_queries, text_weights)
    ]
    entries += [
        ('image', None, image[0], image[1], weight)
        for image, weight in zip(image_queries, image_weights)
    ]

    main_text, main_operation, main_query_type = "", "+", "text"
    main_ra = main_dec = None
    vector_children = []

    if entries:
        head, *rest = entries

        # The first entry populates the main query controls.
        main_query_type, head_text, head_ra, head_dec, head_weight = head
        main_operation = weight_to_operation_str(head_weight)
        if main_query_type == 'text':
            main_text = head_text
        else:
            main_ra, main_dec = head_ra, head_dec

        # Every remaining entry becomes an additional vector row.
        for row_index, (kind, text, ra, dec, weight) in enumerate(rest):
            operation_str = weight_to_operation_str(weight)
            if kind == 'text':
                row = create_vector_input_row(
                    row_index,
                    query_type='text',
                    text_value=text,
                    operation=operation_str
                )
            else:
                row = create_vector_input_row(
                    row_index,
                    query_type='image',
                    ra=ra,
                    dec=dec,
                    operation=operation_str
                )
            vector_children.append(row)

    return (
        main_text,                 # search-input-advanced
        main_operation,            # main-vector-operation
        main_query_type,           # main-query-type
        main_ra,                   # main-vector-ra
        main_dec,                  # main-vector-dec
        vector_children,           # vector-inputs
        slider_value,              # rmag-slider
        {"display": "block"},      # advanced-search-interface (show)
        {"display": "none"},       # basic-search-bar (hide)
        "fas fa-chevron-up me-2",  # vector-arrow (advanced mode)
        1,                         # url-search-trigger
        expand_galaxy,             # pending-expand-galaxy
        keep                       # search-input (untouched in advanced mode)
    )
def trigger_search_from_url(trigger, current_basic_clicks, current_advanced_clicks, basic_style):
    """Bump the click count of the matching search button after URL restore.

    Basic mode is assumed when ``basic_style`` has ``display == "block"``.

    Returns:
        ``(basic_clicks, advanced_clicks)`` where exactly one is incremented
        and the other is ``dash.no_update``; both are ``dash.no_update`` when
        ``trigger`` is falsy.
    """
    if not trigger:
        return dash.no_update, dash.no_update

    if basic_style.get("display") == "block":
        return (current_basic_clicks or 0) + 1, dash.no_update
    return dash.no_update, (current_advanced_clicks or 0) + 1
def update_url_after_search(search_params, current_url):
    """Produce the URL search string that encodes the completed search.

    Returns:
        ``"?<URL_STATE_PARAM>=<encoded state>"``, or ``dash.no_update`` when
        there are no search params.

    Raises:
        dash.exceptions.PreventUpdate: when the new string equals
            ``current_url``.
    """
    from src.url_state import encode_search_state
    from src.config import URL_STATE_PARAM

    if not search_params:
        return dash.no_update

    encoded_state = encode_search_state(
        text_queries=search_params.get('text_queries', []),
        text_weights=search_params.get('text_weights', []),
        image_queries=search_params.get('image_queries', []),
        image_weights=search_params.get('image_weights', []),
        rmag_min=search_params.get('rmag_min', 13.0),
        rmag_max=search_params.get('rmag_max', 20.0)
    )
    candidate_url = f"?{URL_STATE_PARAM}={encoded_state}"

    # Skip the update when the URL would not change.
    if candidate_url == current_url:
        raise dash.exceptions.PreventUpdate
    return candidate_url
# Clientside callback: copy a shareable link for the current search to the clipboard.
# Runs in the browser; writes a short status message into "copy-results-feedback".
#
# NOTE(review): the "s" query parameter and the 13.0 / 20.0 r-mag defaults are
# hard-coded in the JavaScript below; confirm they stay in sync with
# URL_STATE_PARAM and the server-side state encoder. The payload is standard
# base64 with padding stripped; verify the decoder accepts that alphabet.
app.clientside_callback(
    """
    function(n_clicks, search_params) {
        if (!n_clicks || !search_params) {
            return "";
        }
        // Build the compact URL state from the stored search params
        var state = {};
        if (search_params.text_queries && search_params.text_queries.length > 0) {
            state.tq = search_params.text_queries;
            state.tw = search_params.text_weights;
        }
        if (search_params.image_queries && search_params.image_queries.length > 0) {
            state.iq = search_params.image_queries;
            state.iw = search_params.image_weights;
        }
        if (search_params.rmag_min !== 13.0) {
            state.rmin = search_params.rmag_min;
        }
        if (search_params.rmag_max !== 20.0) {
            state.rmax = search_params.rmag_max;
        }
        // btoa() only accepts Latin-1 input, so turn the JSON into UTF-8 bytes
        // first; otherwise a query with other characters throws and nothing is copied
        var utf8 = new TextEncoder().encode(JSON.stringify(state));
        var binary = "";
        for (var i = 0; i < utf8.length; i++) {
            binary += String.fromCharCode(utf8[i]);
        }
        var encoded = btoa(binary).replace(/=/g, '');
        // Keep the pathname so the link still works when the app is not served from "/"
        var url = window.location.origin + window.location.pathname + "?s=" + encoded;
        // Copy through an off-screen textarea
        var textArea = document.createElement("textarea");
        textArea.value = url;
        textArea.style.position = "fixed";
        textArea.style.left = "-999999px";
        textArea.style.top = "-999999px";
        document.body.appendChild(textArea);
        textArea.focus();
        textArea.select();
        // execCommand reports failure by returning false (it rarely throws),
        // so the result has to be checked instead of relying on the catch
        var copied = false;
        try {
            copied = document.execCommand('copy');
        } catch (err) {
            copied = false;
        } finally {
            document.body.removeChild(textArea);
        }
        if (!copied && navigator.clipboard && navigator.clipboard.writeText) {
            // Async Clipboard API as fallback; correct the feedback if it rejects
            navigator.clipboard.writeText(url).catch(function() {
                var el = document.getElementById('copy-results-feedback');
                if (el) el.textContent = 'Failed to copy';
            });
            copied = true;
        }
        if (!copied) {
            return "Failed to copy";
        }
        // Clear feedback after 2 seconds
        setTimeout(function() {
            var el = document.getElementById('copy-results-feedback');
            if (el) el.textContent = '';
        }, 2000);
        return "Copied!";
    }
    """,
    Output("copy-results-feedback", "children"),
    Input("copy-results-link", "n_clicks"),
    State("current-search-params", "data"),
    prevent_initial_call=True
)
# Clientside callback: copy a shareable link that reproduces the current search
# and expands the currently selected galaxy. Runs in the browser; writes a short
# status message into "copy-galaxy-feedback".
#
# NOTE(review): the "s" query parameter and the 13.0 / 20.0 r-mag defaults are
# hard-coded in the JavaScript below; confirm they stay in sync with
# URL_STATE_PARAM and the server-side state encoder. The payload is standard
# base64 with padding stripped; verify the decoder accepts that alphabet.
app.clientside_callback(
    """
    function(n_clicks, search_params, galaxy_data) {
        if (!n_clicks || !search_params || !galaxy_data) {
            return "";
        }
        // Build the compact URL state from the stored search params
        var state = {};
        if (search_params.text_queries && search_params.text_queries.length > 0) {
            state.tq = search_params.text_queries;
            state.tw = search_params.text_weights;
        }
        if (search_params.image_queries && search_params.image_queries.length > 0) {
            state.iq = search_params.image_queries;
            state.iw = search_params.image_weights;
        }
        if (search_params.rmag_min !== 13.0) {
            state.rmin = search_params.rmag_min;
        }
        if (search_params.rmag_max !== 20.0) {
            state.rmax = search_params.rmag_max;
        }
        // Add galaxy expansion
        if (galaxy_data.object_id) {
            var galaxyId = galaxy_data.object_id;
            // Clean up bytes notation if present (e.g., "b'1500m885-7090'" -> "1500m885-7090")
            if (typeof galaxyId === 'string' && galaxyId.startsWith("b'") && galaxyId.endsWith("'")) {
                galaxyId = galaxyId.slice(2, -1);
            }
            state.exp = galaxyId;
        }
        // btoa() only accepts Latin-1 input, so turn the JSON into UTF-8 bytes
        // first; otherwise a query with other characters throws and nothing is copied
        var utf8 = new TextEncoder().encode(JSON.stringify(state));
        var binary = "";
        for (var i = 0; i < utf8.length; i++) {
            binary += String.fromCharCode(utf8[i]);
        }
        var encoded = btoa(binary).replace(/=/g, '');
        // Keep the pathname so the link still works when the app is not served from "/"
        var url = window.location.origin + window.location.pathname + "?s=" + encoded;
        // Copy through an off-screen textarea
        var textArea = document.createElement("textarea");
        textArea.value = url;
        textArea.style.position = "fixed";
        textArea.style.left = "-999999px";
        textArea.style.top = "-999999px";
        document.body.appendChild(textArea);
        textArea.focus();
        textArea.select();
        // execCommand reports failure by returning false (it rarely throws),
        // so the result has to be checked instead of relying on the catch
        var copied = false;
        try {
            copied = document.execCommand('copy');
        } catch (err) {
            copied = false;
        } finally {
            document.body.removeChild(textArea);
        }
        if (!copied && navigator.clipboard && navigator.clipboard.writeText) {
            // Async Clipboard API as fallback; correct the feedback if it rejects
            navigator.clipboard.writeText(url).catch(function() {
                var el = document.getElementById('copy-galaxy-feedback');
                if (el) el.textContent = 'Failed to copy';
            });
            copied = true;
        }
        if (!copied) {
            return "Failed to copy";
        }
        // Clear feedback after 2 seconds
        setTimeout(function() {
            var el = document.getElementById('copy-galaxy-feedback');
            if (el) el.textContent = '';
        }, 2000);
        return "Copied!";
    }
    """,
    Output("copy-galaxy-feedback", "children"),
    Input("copy-galaxy-link", "n_clicks"),
    [State("current-search-params", "data"),
     State("current-galaxy-data", "data")],
    prevent_initial_call=True
)
| # Helper functions for callbacks | |
def weight_to_operation_str(weight):
    """Translate a numeric weight into the operation string used by the UI.

    Unit weights map to a bare sign ("+" for 1.0, "-" for -1.0). Any other
    weight becomes its integer part (truncated toward zero by int()), with an
    explicit "+" prefix when the weight is positive.

    Args:
        weight: Numeric weight of a query vector

    Returns:
        Operation string such as "+", "-", "+2" or "-5"
    """
    if weight == 1.0:
        return "+"
    if weight == -1.0:
        return "-"

    # Negative integers already carry their sign; only positives need a prefix
    sign_prefix = "+" if weight > 0 else ""
    return f"{sign_prefix}{int(weight)}"
def build_galaxy_grid(df: pd.DataFrame) -> list:
    """Turn each row of a results DataFrame into a galaxy card.

    The DataFrame index label of each row (as yielded by iterrows) is passed
    to build_galaxy_card as the card index.

    Args:
        df: DataFrame with the primary key, ra, dec, distance, r_mag and cutout_url columns

    Returns:
        List of Dash components, one per row, in row order
    """
    card_fields = (ZILLIZ_PRIMARY_KEY, "ra", "dec", "distance", "r_mag", "cutout_url")
    return [
        build_galaxy_card({field: record[field] for field in card_fields}, label)
        for label, record in df.iterrows()
    ]
def build_galaxy_card(galaxy_info: dict, index: int):
    """Assemble the grid card for one galaxy: a clickable cutout with an r-mag badge.

    Args:
        galaxy_info: Dictionary providing at least "cutout_url" and "r_mag"
        index: Value stored in the image's pattern-matching id under "index"

    Returns:
        Dash Bootstrap Col component wrapping the card
    """
    thumbnail = html.Img(
        src=galaxy_info["cutout_url"],
        style={
            "width": IMAGE_WIDTH,
            "height": IMAGE_HEIGHT,
            "object-fit": "cover",
            "cursor": "pointer",
            "border-radius": "8px"
        },
        id={"type": "galaxy-image", "index": index},
        className="hover-shadow"
    )

    # Badge pinned to the bottom-right corner of the image
    magnitude_badge = html.Div(
        [html.Small(f"r = {galaxy_info['r_mag']:.2f} mag", className="score-badge")],
        style={"position": "absolute", "bottom": "8px", "right": "8px"}
    )

    # The relative wrapper is the positioning context for the absolute badge
    image_frame = html.Div([thumbnail, magnitude_badge], style={"position": "relative"})
    card_body = html.Div([image_frame])

    return dbc.Col([card_body], width=6, md=4, lg=2, className="mb-2 px-1")
def prepare_search_data(df: pd.DataFrame, query: str, is_vector_search: bool = False, request_id: str = None) -> dict:
    """Flatten search results into a plain dict of lists plus search metadata.

    Args:
        df: DataFrame with the primary key, ra, dec, distance, r_mag and cutout_url columns
        query: Search query string
        is_vector_search: Whether this is a vector search
        request_id: Unique ID for this request; only stored when truthy

    Returns:
        Dictionary with one list per result column, "loaded_count" set to
        DEFAULT_DISPLAY_COUNT, the query, the vector-search flag and,
        optionally, "request_id"
    """
    result_columns = (ZILLIZ_PRIMARY_KEY, "ra", "dec", "distance", "r_mag", "cutout_url")
    payload = {column: df[column].tolist() for column in result_columns}

    payload["loaded_count"] = DEFAULT_DISPLAY_COUNT
    payload["query"] = query
    payload["is_vector_search"] = is_vector_search

    if request_id:
        payload["request_id"] = request_id
    return payload
def extract_galaxy_info(search_data: dict, index: int) -> dict:
    """Pick the values of a single galaxy out of the column-oriented search data.

    Args:
        search_data: Dictionary mapping each field name to a list of values
        index: Position of the galaxy within those lists

    Returns:
        Dictionary with the primary key, ra, dec, distance, r_mag and cutout_url
        of the galaxy at the given index
    """
    galaxy_fields = (ZILLIZ_PRIMARY_KEY, "ra", "dec", "distance", "r_mag", "cutout_url")
    return {field: search_data[field][index] for field in galaxy_fields}
def build_modal_content(galaxy_info: dict) -> tuple:
    """Build modal image and description content.

    The description has three rows: the primary key, RA / Dec, and
    r_mag / cosine similarity. Paired fields are separated by a bullet.

    Args:
        galaxy_info: Dictionary with the primary key, "ra", "dec", "r_mag",
            "distance" and "cutout_url"

    Returns:
        Tuple of (image_element, description_element)
    """
    def _field(text):
        """Return one muted text span of the description."""
        return html.Span(
            text,
            className="d-inline-block mb-0",
            style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"},
        )

    def _separator():
        """Return the bullet shown between two fields on the same row."""
        # Written as an escape so the bullet cannot be garbled by a wrong file
        # encoding (it previously rendered as mojibake instead of a bullet)
        return html.Span(" \u2022 ", className="mx-2", style={"color": "rgba(245, 245, 247, 0.5)"})

    image_element = html.Img(
        src=galaxy_info["cutout_url"],
        style={"width": "100%", "max-width": "350px", "height": "auto"}
    )

    # Format primary key label (convert snake_case to Title Case)
    pk_label = ZILLIZ_PRIMARY_KEY.replace("_", " ").title()

    description_element = html.Div([
        html.Div([
            _field(f"{pk_label}: {galaxy_info[ZILLIZ_PRIMARY_KEY]}"),
        ], className="mb-2"),
        html.Div([
            _field(f"RA: {galaxy_info['ra']:.6f}"),
            _separator(),
            _field(f"Dec: {galaxy_info['dec']:.6f}"),
        ], className="mb-2"),
        html.Div([
            _field(f"r_mag: {galaxy_info['r_mag']:.2f}"),
            _separator(),
            _field(f"Cosine Similarity to Query: {galaxy_info['distance']:.4f}"),
        ], className="mb-3"),
    ])
    return image_element, description_element
def create_load_more_button(total_count: int, current_count: int):
    """Build the full-width button that loads the next batch of results.

    The label announces the size of the next batch: LOAD_MORE_COUNT, or the
    number of results not yet loaded when that is smaller.

    Args:
        total_count: Total number of results
        current_count: Number of currently loaded results

    Returns:
        Dash Bootstrap Button component with id "load-more-button"
    """
    next_batch = min(LOAD_MORE_COUNT, total_count - current_count)
    return dbc.Button(
        f"Load next {next_batch} galaxies",
        id="load-more-button",
        color="secondary",
        className="mt-3",
        style={"width": "100%"}
    )