Spaces:
Sleeping
Sleeping
File size: 13,359 Bytes
dc017a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 |
"""Results stage for the Loci Similes GUI."""
from __future__ import annotations
import csv
import io
import re
from typing import TYPE_CHECKING
try:
import gradio as gr
except ImportError as exc:
missing = getattr(exc, "name", None)
base_msg = (
"Optional GUI dependencies are missing. Install them via "
"'pip install locisimiles[gui]' (Python 3.13+ also requires the "
"audioop-lts backport) to use the Gradio interface."
)
if missing and missing != "gradio":
raise ImportError(f"{base_msg} (missing package: {missing})") from exc
raise ImportError(base_msg) from exc
if TYPE_CHECKING:
from locisimiles.document import Document, TextSegment
import tempfile
from typing import Any, Dict, List, Tuple
try:
import gradio as gr
except ImportError as exc:
raise ImportError("Gradio is required for results page") from exc
from locisimiles.document import Document, TextSegment
# Type aliases from pipeline
FullDict = Dict[str, List[Tuple[TextSegment, float, float]]]
def update_results_display(results: FullDict | None, query_doc: Document | None, threshold: float = 0.5) -> tuple[dict, dict, dict]:
    """Refresh the results view from a fresh pipeline run.

    Args:
        results: Pipeline results
        query_doc: Query document
        threshold: Classification probability threshold for counting finds

    Returns:
        Tuple of (query_segments_update, query_segments_state, matches_dict_state)
    """
    segments, matches = _convert_results_to_display(results, query_doc, threshold)
    # First element updates the visible dataframe; the other two feed the
    # gr.State holders that the selection handler reads later.
    return gr.update(value=segments), segments, matches
def _format_metric_with_bar(value: float, is_above_threshold: bool = False) -> str:
"""Format a metric value with a visual progress bar.
Args:
value: Metric value between 0 and 1
is_above_threshold: Whether to highlight this value
Returns:
HTML string with progress bar
"""
percentage = int(value * 100)
# Choose color based on threshold
if is_above_threshold:
bar_color = "#6B9BD1" # Blue accent for findings
bg_color = "#E3F2FD" # Light blue background
else:
bar_color = "#B0B0B0" # Gray for below threshold
bg_color = "#F5F5F5" # Light gray background
html = f'''
<div style="display: flex; align-items: center; gap: 8px; width: 100%;">
<div style="flex: 1; background-color: {bg_color}; border-radius: 4px; overflow: hidden; height: 20px; position: relative;">
<div style="background-color: {bar_color}; width: {percentage}%; height: 100%; transition: width 0.3s;"></div>
</div>
<span style="min-width: 45px; text-align: right; font-weight: {'bold' if is_above_threshold else 'normal'};">{value:.3f}</span>
</div>
'''
return html
def _convert_results_to_display(results: FullDict | None, query_doc: Document | None, threshold: float = 0.5) -> tuple[list[list], dict]:
    """Convert pipeline results to display format.

    Args:
        results: Pipeline results (FullDict format)
        query_doc: Query document
        threshold: Classification probability threshold for counting finds

    Returns:
        Tuple of (query_segments_list, matches_dict)
    """
    # Nothing to show until a run has completed.
    if results is None or query_doc is None:
        return [], {}

    # Pass 1: rank each candidate list by classifier probability (descending,
    # most likely first) and tally how many clear the threshold.
    ranked_by_query: dict = {}
    hit_counts: dict = {}
    for query_id, candidates in results.items():
        ranked = sorted(candidates, key=lambda item: item[2], reverse=True)
        ranked_by_query[query_id] = ranked
        # Summing booleans counts the matches at or above the threshold.
        hit_counts[query_id] = sum(prob >= threshold for _, _, prob in ranked)

    # The query document iterates its TextSegments in document order.
    query_segments = [
        [segment.id, segment.text, hit_counts.get(segment.id, 0)]
        for segment in query_doc
    ]

    # Pass 2: render the numeric metrics as HTML progress bars for the grid.
    matches_dict = {
        query_id: [
            [
                source_seg.id,
                source_seg.text,
                _format_metric_with_bar(round(similarity, 3), probability >= threshold),
                _format_metric_with_bar(round(probability, 3), probability >= threshold),
            ]
            for source_seg, similarity, probability in ranked
        ]
        for query_id, ranked in ranked_by_query.items()
    }
    return query_segments, matches_dict
def _on_query_select(evt: gr.SelectData, query_segments: list, matches_dict: dict) -> tuple[dict, dict]:
    """Handle query segment selection and return matching source segments.

    Note: evt.index[0] gives the row number when clicking anywhere in that row.

    Args:
        evt: Selection event data
        query_segments: List of query segments
        matches_dict: Dictionary mapping query IDs to matches

    Returns:
        A tuple of (prompt_visibility_update, dataframe_update_with_data)
    """
    # Fallback state: keep the "select a segment" prompt, hide the table.
    show_prompt = (gr.update(visible=True), gr.update(visible=False))
    index = evt.index
    if index is None or len(index) < 1:
        return show_prompt
    row = index[0]
    # Guard against stale events referencing rows beyond the current data.
    if row >= len(query_segments):
        return show_prompt
    selected_id = query_segments[row][0]
    # Swap the prompt for the populated matches dataframe.
    return gr.update(visible=False), gr.update(value=matches_dict.get(selected_id, []), visible=True)
def _extract_numeric_from_html(html_str: str) -> float:
"""Extract numeric value from HTML formatted metric string.
Args:
html_str: HTML string with embedded numeric value
Returns:
Extracted numeric value
"""
import re
# Extract the number from the span at the end: <span ...>0.XXX</span>
match = re.search(r'<span[^>]*>([\d.]+)</span>', html_str)
if match:
return float(match.group(1))
# Fallback: if it's already a number
try:
return float(html_str)
except (ValueError, TypeError):
return 0.0
def _export_results_to_csv(query_segments: list, matches_dict: dict, threshold: float) -> str:
"""Export results to a CSV file.
Args:
query_segments: List of query segments with find counts
matches_dict: Dictionary mapping query IDs to matches
threshold: Classification probability threshold
Returns:
Path to the temporary CSV file
"""
# Create a temporary file
temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv', newline='', encoding='utf-8')
with temp_file as f:
writer = csv.writer(f)
# Write header
writer.writerow([
"Query_Segment_ID",
"Query_Text",
"Source_Segment_ID",
"Source_Text",
"Similarity",
"Probability",
"Above_Threshold"
])
# Write data for each query segment
for query_row in query_segments:
query_id = query_row[0]
query_text = query_row[1]
# Get matches for this query segment
matches = matches_dict.get(query_id, [])
if matches:
for match in matches:
source_id = match[0]
source_text = match[1]
# Extract numeric values from HTML formatted strings
similarity = _extract_numeric_from_html(match[2]) if isinstance(match[2], str) else match[2]
probability = _extract_numeric_from_html(match[3]) if isinstance(match[3], str) else match[3]
above_threshold = "Yes" if probability >= threshold else "No"
writer.writerow([
query_id,
query_text,
source_id,
source_text,
similarity,
probability,
above_threshold
])
else:
# Write row even if no matches
writer.writerow([
query_id,
query_text,
"",
"",
"",
"",
""
])
return temp_file.name
def build_results_stage() -> tuple[gr.Step, dict[str, Any]]:
    """Build the results stage UI.

    Lays out a two-column view: query segments on the left, their candidate
    intertextual matches on the right, plus CSV download and restart controls.
    No event handlers are attached here — see ``setup_results_handlers``.

    Returns:
        A tuple of (Step component, components_dict) where components_dict contains
        references to all interactive components that need to be accessed later.
    """
    # id=2 places this as the third step of the enclosing gr.Walkthrough.
    with gr.Step("Results", id=2) as step:
        # State to hold current query segments and matches; populated by
        # update_results_display and read by the selection handler.
        query_segments_state = gr.State(value=[])
        matches_dict_state = gr.State(value={})
        gr.Markdown("### π Step 3: View Results")
        gr.Markdown(
            "Select a query segment on the left to view potential intertextual references from the source document. "
            "Similarity measures the cosine similarity between embeddings (0-1, higher = more similar). "
            "Probability is the classifier's confidence that the pair represents an intertextual reference (0-1, higher = more likely)."
        )
        # Download button (file is produced lazily by _export_results_to_csv)
        with gr.Row():
            download_btn = gr.DownloadButton("Download Results as CSV", variant="primary")
        with gr.Row():
            # Left column: Query segments
            with gr.Column(scale=1):
                gr.Markdown("### Query Document Segments")
                query_segments = gr.Dataframe(
                    value=[],
                    headers=["Segment ID", "Text", "Finds"],
                    interactive=False,
                    show_label=False,
                    label="Query Document Segments",
                    wrap=True,
                    max_height=600,
                    col_count=(3, "fixed"),
                )
            # Right column: Matching source segments
            with gr.Column(scale=1):
                gr.Markdown("### Potential Intertextual References")
                # Prompt shown initially; swapped for the dataframe on selection
                selection_prompt = gr.Markdown(
                    """
                    <div style="display: flex; align-items: center; justify-content: center; height: 400px; font-size: 18px; color: #666;">
                        <div style="text-align: center;">
                            <div style="font-size: 48px; margin-bottom: 20px;">β</div>
                            <div>Select a query segment to view</div>
                            <div>potential intertextual references</div>
                        </div>
                    </div>
                    """,
                    visible=True
                )
                # Dataframe hidden initially
                source_matches = gr.Dataframe(
                    headers=["Source ID", "Source Text", "Similarity", "Probability"],
                    interactive=False,
                    show_label=False,
                    label="Potential Intertextual References from Source Document",
                    wrap=True,
                    max_height=600,
                    visible=False,
                    datatype=["str", "str", "html", "html"],  # Enable HTML rendering for metric columns
                )
        with gr.Row():
            restart_btn = gr.Button("β Start Over", size="lg")
    # Return the step and all components that need to be accessed/wired later
    components = {
        "query_segments": query_segments,
        "query_segments_state": query_segments_state,
        "matches_dict_state": matches_dict_state,
        "source_matches": source_matches,
        "selection_prompt": selection_prompt,
        "download_btn": download_btn,
        "restart_btn": restart_btn,
    }
    return step, components
def setup_results_handlers(components: dict, walkthrough: gr.Walkthrough) -> None:
    """Set up event handlers for the results stage.

    Args:
        components: Dictionary of UI components from build_results_stage
        walkthrough: The Walkthrough component for navigation
    """
    prompt = components["selection_prompt"]
    matches_table = components["source_matches"]

    # Clicking a row in the query dataframe reveals its candidate matches.
    components["query_segments"].select(
        fn=_on_query_select,
        inputs=[components["query_segments_state"], components["matches_dict_state"]],
        outputs=[prompt, matches_table],
    )

    def _go_to_start() -> gr.Walkthrough:
        # Jump the walkthrough back to its first step (Step 3 β Step 1).
        return gr.Walkthrough(selected=0)

    components["restart_btn"].click(fn=_go_to_start, outputs=walkthrough)
|