mk1985 committed on
Commit
e9738aa
·
verified ·
1 Parent(s): 0e88058

Update app.py

Files changed (1)
  1. app.py +266 -243
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # 📚 Install dependencies
2
  # Make sure to run this in your environment if you haven't already
3
- # !pip install openai anthropic google-generativeai gradio transformers torch gliner pandas --quiet
4
 
5
  # βš™οΈ Imports
6
  import openai
@@ -8,12 +8,10 @@ import anthropic
8
  import google.generativeai as genai
9
  import gradio as gr
10
  from gliner import GLiNER
11
- import traceback
12
  from collections import defaultdict, Counter
13
- import re
14
  import os
15
- import pandas as pd
16
- import tempfile
17
 
18
  # 🧠 Supported models and their providers
19
  MODEL_OPTIONS = {
@@ -27,317 +25,342 @@ GLINER_MODEL_NAME = "urchade/gliner_large-v2.1"
27
 
28
  # --- Load the model only once at startup ---
29
  try:
30
- print("Loading GLiNER model... This may take a moment.")
31
  gliner_model = GLiNER.from_pretrained(GLINER_MODEL_NAME)
32
- print("GLiNER model loaded successfully.")
33
  except Exception as e:
34
  print(f"FATAL ERROR: Could not load GLiNER model. The app will not be able to find entities. Error: {e}")
35
  gliner_model = None
36
 
37
- # --- Prompt and other constants remain the same ---
38
- HIERARCHICAL_PROMPT_TEMPLATE = """You are a helpful research assistant specializing in history. Your task is to brainstorm a hierarchical set of keywords and named entities related to a historical topic.
39
-
40
- The user will provide a topic. You should generate a structured list of categories and, for each category, a comma-separated list of relevant keywords or phrases. These keywords should be things a researcher might want to search for in a historical text.
41
-
42
- Rules:
43
- 1. Structure your response using Markdown.
44
- 2. Use '###' for each category title (e.g., '### Key Figures').
45
- 3. Beneath each category, provide a single bullet point '-' followed by a comma-separated list of 5-10 specific keywords or entities.
46
- 4. Do not add any introductory or concluding sentences. Just provide the structured list.
47
- 5. The keywords should be specific and likely to appear in primary or secondary source documents.
48
-
49
- Example for the topic "The Protestant Reformation":
50
- ### Key Figures
51
- - Martin Luther, John Calvin, Huldrych Zwingli, Henry VIII, Charles V, Pope Leo X
52
- ### Core Theological Concepts
53
- - Sola Scriptura, Sola Fide, Indulgences, Priesthood of all believers, Justification by faith
54
- ### Key Events
55
- - Diet of Worms, Ninety-five Theses, Marburg Colloquy, Council of Trent, Edict of Worms
56
- ### Important Locations
57
- - Wittenberg, Geneva, Rome, Wartburg Castle, Augsburg
58
- ### Associated Groups
59
- - Protestants, Lutherans, Calvinists, Anabaptists, Huguenots, Catholic Church
60
-
61
- Now, generate the framework for the following topic:
62
- Topic: {topic}"""
63
- TRADITIONAL_NER_LABELS = ["PERSON", "ORGANIZATION", "LOCATION", "DATE", "EVENT", "WORK_OF_ART", "LAW"]
64
- MAX_CATEGORIES = 8
65
-
66
- with gr.Blocks(title="Historical Text Analysis Tool", css=".prose { word-break: break-word; }") as demo:
67
- # --- UI remains the same up to the output tabs ---
68
- gr.Markdown("# Historical Text Analysis Tool")
69
- gr.Markdown("A tool to help historians and researchers quickly identify key terms and concepts in texts. Start by generating keyword ideas for a topic, then paste your text to find all occurrences.") # Welcome text collapsed for brevity
70
- gr.Markdown("---")
71
- gr.Markdown("## Step 1: Get Keyword Ideas")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  with gr.Row():
73
- topic = gr.Textbox(label="Enter Historical Topic", placeholder="e.g., The Chartist Movement")
74
- provider = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Choose AI Model")
75
  with gr.Row():
76
  openai_key = gr.Textbox(label="OpenAI API Key", type="password")
77
  anthropic_key = gr.Textbox(label="Anthropic API Key", type="password")
78
  google_key = gr.Textbox(label="Google API Key", type="password")
79
- generate_btn = gr.Button("Suggest Categories and Keywords", variant="primary")
80
 
81
- gr.Markdown("--- \n## Step 2: Build Your Search and Analyze Text")
82
- category_components = []
83
  with gr.Column():
84
  for i in range(MAX_CATEGORIES):
85
- with gr.Accordion(f"Category {i+1}", visible=False) as acc:
86
  with gr.Row():
87
- cg = gr.CheckboxGroup(label="Keywords", interactive=True, container=False, scale=4)
88
- toggle_btn = gr.Button("Deselect All", size="sm", scale=1, min_width=100)
89
- category_components.append((acc, cg, toggle_btn))
90
  with gr.Group():
91
- ner_output = gr.CheckboxGroup(choices=TRADITIONAL_NER_LABELS, value=TRADITIONAL_NER_LABELS, label="Standard Search Terms")
92
- toggle_ner_btn = gr.Button("Deselect All", size="sm")
93
  with gr.Group():
94
- custom_labels = gr.Textbox(label="Add Your Own Keywords (Optional)", placeholder="e.g., Technology, Weapon... (separated by commas)")
95
- threshold_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold", info="Controls how 'sure' the AI needs to be. Lower finds more potential matches, higher finds only the most certain ones.")
96
- text_input = gr.Textbox(label="Paste Your Full Text Here for Analysis", lines=10)
97
- match_btn = gr.Button("Find Keywords in Text", variant="primary")
98
 
99
- # --- NEW: Add state variables to hold data between function calls ---
100
- # This holds the original text for updates
101
- text_state = gr.State()
102
- # This holds the results DataFrame for updates and downloads
103
- dataframe_state = gr.State()
104
-
105
  with gr.Tabs():
106
  with gr.TabItem("Highlighted Text"):
107
- matched_output = gr.HighlightedText(
108
- label="Keyword Matches",
109
- interactive=True,
110
- show_legend=True
111
- )
112
  with gr.TabItem("Detailed Results"):
113
- # --- CHANGE: Using gr.DataFrame for a clean table output ---
114
- detailed_results_output = gr.DataFrame(
115
- headers=["Category", "Found Phrase", "Occurrences"],
116
- datatype=["str", "str", "number"],
117
- wrap=True,
118
- label="Aggregated Results"
119
- )
120
- # --- NEW: Download button and hidden file component ---
121
- download_button = gr.Button("Download Results as CSV", visible=False)
122
- download_file = gr.File(label="Download", visible=False)
123
-
124
  with gr.TabItem("Debug Info"):
125
  debug_output = gr.Textbox(label="Extraction Log", interactive=False, lines=8)
126
 
127
  # --- Backend Functions ---
128
 
129
- # --- THIS IS THE MISSING FUNCTION THAT WAS ADDED ---
130
- def generate_from_prompt(prompt, provider, key_dict):
131
- """Calls the appropriate LLM API based on the selected provider."""
132
- provider_id = MODEL_OPTIONS.get(provider)
133
-
134
- if provider_id == "openai":
135
- client = openai.OpenAI(api_key=key_dict["openai_key"])
136
- response = client.chat.completions.create(
137
- model="gpt-4o",
138
- messages=[{"role": "user", "content": prompt}]
139
- )
140
- return response.choices[0].message.content
141
-
142
- elif provider_id == "anthropic":
143
- client = anthropic.Anthropic(api_key=key_dict["anthropic_key"])
144
- response = client.messages.create(
145
- model="claude-3-opus-20240229",
146
- max_tokens=1024,
147
- messages=[{"role": "user", "content": prompt}]
148
- )
149
- return response.content[0].text
150
-
151
- elif provider_id == "google":
152
- genai.configure(api_key=key_dict["google_key"])
153
- model = genai.GenerativeModel('gemini-1.5-pro-latest')
154
- response = model.generate_content(prompt)
155
- return response.text
156
-
157
- else:
158
- raise ValueError("Invalid provider selected")
159
-
160
  def handle_generate(topic, provider, openai_k, anthropic_k, google_k):
161
- # ... (This function remains unchanged) ...
162
- yield {generate_btn: gr.update(value="Consulting the Archives...", interactive=False)}
 
 
163
  try:
164
- key_dict = {"openai_key": os.environ.get("OPENAI_API_KEY", openai_k), "anthropic_key": os.environ.get("ANTHROPIC_API_KEY", anthropic_k), "google_key": os.environ.get("GOOGLE_API_KEY", google_k)}
165
  provider_id = MODEL_OPTIONS.get(provider)
166
- if not topic or not provider or not key_dict.get(f"{provider_id}_key"): raise gr.Error("Topic, Provider, and the correct API Key are required.")
 
 
167
  prompt = HIERARCHICAL_PROMPT_TEMPLATE.format(topic=topic)
168
  raw_framework = generate_from_prompt(prompt, provider, key_dict)
 
 
169
  framework = defaultdict(list)
170
  current_category = None
171
  for line in raw_framework.split('\n'):
172
  line = line.strip()
173
- if line.startswith("###"): current_category = line.replace("###", "").strip()
174
- elif line.startswith("-") and current_category: framework[current_category].extend([e.strip() for e in line.replace("-", "").strip().split(',') if e.strip()])
175
- if not framework: raise gr.Error("AI failed to generate categories. Please try again.")
176
  updates = {}
177
  categories = list(framework.items())
178
  for i in range(MAX_CATEGORIES):
179
- accordion_comp, checkbox_comp, toggle_btn_comp = category_components[i]
180
  if i < len(categories):
181
- category, entities = categories[i]
 
182
  sorted_entities = sorted(list(set(entities)))
183
- updates[accordion_comp] = gr.update(label=category, visible=True)
184
- updates[checkbox_comp] = gr.update(choices=sorted_entities, value=sorted_entities, visible=True)
185
- updates[toggle_btn_comp] = gr.update(visible=True, value="Deselect All")
186
  else:
187
  updates[accordion_comp] = gr.update(visible=False)
188
  updates[checkbox_comp] = gr.update(visible=False)
189
- updates[toggle_btn_comp] = gr.update(visible=False)
190
- updates[generate_btn] = gr.update(value="Suggest Categories and Keywords", interactive=True)
 
191
  yield updates
192
  except Exception as e:
193
- yield {generate_btn: gr.update(value="Suggest Categories and Keywords", interactive=True)}
194
  raise gr.Error(str(e))
195
 
196
- # --- NEW: Helper function to process entities into a DataFrame ---
197
- def process_entities_to_df(entities, original_text):
198
- """Takes a list of entities and the original text, and returns a pandas DataFrame."""
199
- if not entities:
200
- return pd.DataFrame(columns=["Category", "Found Phrase", "Occurrences"])
201
-
202
- # Extract text for each entity
203
- found_phrases = []
204
- for ent in entities:
205
- found_phrases.append({
206
- "Category": ent['entity'],
207
- "Found Phrase": original_text[ent['start']:ent['end']]
208
- })
209
-
210
- if not found_phrases:
211
- return pd.DataFrame(columns=["Category", "Found Phrase", "Occurrences"])
212
-
213
- # Aggregate using pandas
214
- df = pd.DataFrame(found_phrases)
215
- aggregated_df = df.groupby(["Category", "Found Phrase"]).size().reset_index(name="Occurrences")
216
- aggregated_df = aggregated_df.sort_values(by=["Category", "Occurrences"], ascending=[True, False])
217
-
218
- return aggregated_df
219
-
220
- # --- UPDATED: `match_entities` now uses pandas and updates state ---
221
- def match_entities(text, ner_labels, custom_label_text, threshold, *selected_keywords, progress=gr.Progress(track_tqdm=True)):
222
  yield {
223
- match_btn: gr.update(value="Searching...", interactive=False),
 
 
224
  detailed_results_output: None,
225
- download_button: gr.update(visible=False),
226
- download_file: gr.update(visible=False)
227
  }
228
- if gliner_model is None: raise gr.Error("GLiNER model failed to load.")
229
230
  labels_to_use = set()
231
- if ner_labels: labels_to_use.update(ner_labels)
232
- for group in selected_keywords:
233
  if group: labels_to_use.update(group)
 
 
 
234
  custom = {l.strip() for l in custom_label_text.split(',') if l.strip()}
235
  if custom: labels_to_use.update(custom)
236
- final_labels = sorted(list(labels_to_use))
237
- debug_info = [f"🧠 Searching for {len(final_labels)} unique keywords.", f"⚙️ Confidence Threshold: {threshold}"]
238
239
  if not text or not final_labels:
240
- yield {match_btn: gr.update(value="Find Keywords in Text", interactive=True)}
241
  return
242
-
 
243
  all_entities = []
244
- chunk_size, overlap = 1000, 50
245
- for i in progress.tqdm(range(0, len(text), chunk_size - overlap), desc="Scanning Text..."):
 
246
  chunk = text[i : i + chunk_size]
247
  chunk_entities = gliner_model.predict_entities(chunk, final_labels, threshold=threshold)
248
  for ent in chunk_entities:
249
- ent['start'] += i; ent['end'] += i
 
250
  all_entities.append(ent)
251
 
 
252
  unique_entities = [dict(t) for t in {tuple(d.items()) for d in all_entities}]
253
- debug_info.append(f"📊 Found {len(unique_entities)} unique matches.")
254
 
255
- highlighted_entities = [{"start": ent["start"], "end": ent["end"], "label": ent["label"]} for ent in unique_entities]
256
 
257
- # --- NEW: Use helper to create DataFrame ---
258
- results_df = process_entities_to_df(highlighted_entities, text)
259
260
  yield {
261
- match_btn: gr.update(value="Find Keywords in Text", interactive=True),
262
- matched_output: {"text": text, "entities": highlighted_entities},
263
- detailed_results_output: results_df,
264
- debug_output: "\n".join(debug_info),
265
- download_button: gr.update(visible=True if not results_df.empty else False),
266
- text_state: text, # Store original text in state
267
- dataframe_state: results_df # Store dataframe in state
268
  }
269
 
270
- # --- NEW: Function to update results when highlighted text is edited ---
271
- def update_detailed_results(new_highlighted_entities, original_text):
272
- """
273
- This function is triggered when the user edits the HighlightedText component.
274
- It re-calculates the DataFrame and updates the UI.
275
- """
276
- # new_highlighted_entities is the full value of the component, not just a diff
277
- # In Gradio > 4, the format is a list of dictionaries with 'entity', 'start', 'end'
278
- results_df = process_entities_to_df(new_highlighted_entities, original_text)
279
-
280
- return {
281
- detailed_results_output: results_df,
282
- dataframe_state: results_df, # Update the state for the download button
283
- download_button: gr.update(visible=True if not results_df.empty else False),
284
- }
285
-
286
- # --- NEW: Function to handle the file download ---
287
- def download_results_as_csv(df):
288
- """Saves the DataFrame to a temporary CSV file and returns its path."""
289
- with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.csv', encoding='utf-8') as tmp:
290
- df.to_csv(tmp.name, index=False)
291
- return gr.update(value=tmp.name, visible=True)
292
-
293
- # --- Event Wiring ---
294
- def handle_toggle_click(button_text, all_choices):
295
- if button_text == "Select All": return gr.update(value=all_choices), gr.update(value="Deselect All")
296
- else: return gr.update(value=[]), gr.update(value="Select All")
297
- def update_button_on_check(selections):
298
- return gr.update(value="Select All") if not selections else gr.update(value="Deselect All")
299
-
300
- submit_event_args = {"fn": handle_generate, "inputs": [topic, provider, openai_key, anthropic_key, google_key], "outputs": [generate_btn] + [comp for pair in category_components for comp in pair]}
301
- generate_btn.click(**submit_event_args)
302
- topic.submit(**submit_event_args)
303
-
304
- toggle_ner_btn.click(fn=handle_toggle_click, inputs=[toggle_ner_btn, gr.State(TRADITIONAL_NER_LABELS)], outputs=[ner_output, toggle_ner_btn])
305
- ner_output.change(fn=update_button_on_check, inputs=[ner_output], outputs=[toggle_ner_btn])
306
-
307
- def create_toggle_handler(cg_component):
308
- # We need a closure to capture the correct cg_component for each button
309
- def handler(button_text):
310
- # Gradio provides the component's choices at runtime, so we can access them here
311
- return handle_toggle_click(button_text, cg_component.choices)
312
- return handler
313
-
314
- for acc, cg, toggle_btn in category_components:
315
- # Note: We pass the component itself to gr.State to get its properties in the handler
316
- toggle_btn.click(
317
- fn=lambda btn_txt, choices: handle_toggle_click(btn_txt, choices),
318
- inputs=[toggle_btn, gr.State(cg.choices)],
319
- outputs=[cg, toggle_btn]
320
- )
321
- cg.change(fn=update_button_on_check, inputs=[cg], outputs=[toggle_btn])
322
-
323
- match_btn.click(
324
- fn=match_entities,
325
- inputs=[text_input, ner_output, custom_labels, threshold_slider] + [cg for acc, cg, btn in category_components],
326
- # --- CHANGE: Added new state and download components to outputs ---
327
- outputs=[match_btn, matched_output, detailed_results_output, debug_output, download_button, download_file, text_state, dataframe_state]
328
  )
329
 
330
- # --- NEW: Wire up the dynamic update and download events ---
331
- matched_output.change(
332
- fn=update_detailed_results,
333
- inputs=[matched_output, text_state],
334
- outputs=[detailed_results_output, dataframe_state, download_button]
335
- )
336
 
337
- download_button.click(
338
- fn=download_results_as_csv,
339
- inputs=[dataframe_state],
340
- outputs=[download_file]
 
 
 
341
  )
342
 
343
  demo.launch(share=True, debug=True)
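For reference, the chunked scanning pattern shared by the old match_entities and the new analyze_text_and_find_entities can be exercised on its own. A minimal sketch, assuming the same GLiNER.from_pretrained / predict_entities API used in app.py; the helper name scan_text, the chunk sizes, and the example labels are illustrative only:

from gliner import GLiNER

# Load once, as app.py does at startup.
gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")

def scan_text(text, labels, threshold=0.4, chunk_size=1024, overlap=100):
    """Scan a long text in overlapping chunks and return deduplicated entities."""
    all_entities = []
    for i in range(0, len(text), chunk_size - overlap):
        chunk = text[i : i + chunk_size]
        for ent in gliner_model.predict_entities(chunk, labels, threshold=threshold):
            # Re-anchor chunk-local offsets to the full text, as the app does.
            ent["start"] += i
            ent["end"] += i
            all_entities.append(ent)
    # Entities detected twice inside an overlap collapse to a single record.
    return [dict(t) for t in {tuple(d.items()) for d in all_entities}]

# Example (illustrative labels):
# entities = scan_text(open("speech.txt").read(), ["PERSON", "DATE", "EVENT"])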
 
1
  # 📚 Install dependencies
2
  # Make sure to run this in your environment if you haven't already
3
+ # !pip install openai anthropic google-generativeai gradio transformers torch gliner --quiet
4
 
5
  # βš™οΈ Imports
6
  import openai
 
8
  import google.generativeai as genai
9
  import gradio as gr
10
  from gliner import GLiNER
11
+ import traceback
12
  from collections import defaultdict, Counter
13
+ import numpy as np # For calculating average score
14
  import os
 
 
15
 
16
  # 🧠 Supported models and their providers
17
  MODEL_OPTIONS = {
 
25
 
26
  # --- Load the model only once at startup ---
27
  try:
28
+ print("Loading AI Detective (GLiNER model)... This may take a moment.")
29
  gliner_model = GLiNER.from_pretrained(GLINER_MODEL_NAME)
30
+ print("AI Detective loaded successfully.")
31
  except Exception as e:
32
  print(f"FATAL ERROR: Could not load GLiNER model. The app will not be able to find entities. Error: {e}")
33
  gliner_model = None
34
 
35
+ # 🧠 Prompt for the Creative AI to generate label ideas
36
+ HIERARCHICAL_PROMPT_TEMPLATE = """
37
+ You are a helpful research assistant. For the historical topic: **"{topic}"**, your job is to suggest a research framework.
38
+
39
+ **Instructions:**
40
+ 1. First, think of 4-6 **Conceptual Categories** that are useful for analyzing this topic (e.g., 'Forms of Protest', 'Key Demands'). These will become the labels.
41
+ 2. For each category, list specific **Examples** someone could search for in a text.
42
+ 3. **Crucial Rule for Labels:** Use the most basic, fundamental form (e.g., `Petition`, not `Political Petition`).
43
+
44
+ **Output Format:**
45
+ Use Markdown. Each category must be a Level 3 Header (###), followed by a comma-separated list of its examples.
46
+
47
+ ### Example Category 1
48
+ - Example A, Example B, Example C
49
+ ### Example Category 2
50
+ - Example D, Example E
51
+ """
52
+
53
+ # 🧠 Generator Function (The "Creative Brain")
54
+ def generate_from_prompt(prompt, provider, key_dict):
55
+ provider_id = MODEL_OPTIONS.get(provider)
56
+ api_key = key_dict.get(f"{provider_id}_key")
57
+ if not api_key:
58
+ raise ValueError(f"API key for {provider} not found.")
59
+
60
+ if provider_id == "openai":
61
+ client = openai.OpenAI(api_key=api_key)
62
+ response = client.chat.completions.create(model="gpt-4o", messages=[{"role": "user", "content": prompt}], temperature=0.2)
63
+ return response.choices[0].message.content.strip()
64
+ elif provider_id == "anthropic":
65
+ client = anthropic.Anthropic(api_key=api_key)
66
+ response = client.messages.create(model="claude-3-opus-20240229", max_tokens=1024, messages=[{"role": "user", "content": prompt}])
67
+ return response.content[0].text.strip()
68
+ elif provider_id == "google":
69
+ genai.configure(api_key=api_key)
70
+ model = genai.GenerativeModel('gemini-1.5-pro-latest')
71
+ response = model.generate_content(prompt)
72
+ return response.text.strip()
73
+ return ""
74
+
75
+ # --- UI Definitions ---
76
+
77
+ # A list of standard, common labels the user can always choose from
78
+ STANDARD_LABELS = [
79
+ "PERSON", "ORGANIZATION", "LOCATION", "COUNTRY", "CITY", "STATE",
80
+ "NATIONALITY", "GROUP", "DATE", "EVENT", "LAW", "LEGAL_DOCUMENT",
81
+ "PRODUCT", "FACILITY", "WORK_OF_ART", "LANGUAGE", "TIME", "PERCENTAGE",
82
+ "MONEY", "CURRENCY", "QUANTITY", "ORDINAL_NUMBER", "CARDINAL_NUMBER"
83
+ ]
84
+
85
+ MAX_CATEGORIES = 8 # The maximum number of AI-suggested categories to show
86
+
87
+ with gr.Blocks(title="Smart Text Analyzer", css=".prose { word-break: break-word; }") as demo:
88
+ gr.Markdown("# Smart Text Analyzer")
89
+ gr.Markdown(
90
+ """
91
+ Welcome! Paste your text below to automatically find and highlight key information. It's like having two smart assistants read your document for you.
92
+
93
+ ### How It Works: Two Brains are Better Than One!
94
+ We use two different types of AI to give you the best results.
95
+
96
+ 🧠 **1. The Creative Brain (Generative AI - like GPT)**
97
+ This AI is a brainstormer. It reads your topic to understand the context, then *imagines* and *suggests* useful labels that fit your document. It helps you discover what to look for!
98
+
99
+ πŸ•΅οΈ **2. The Detective (Extractive AI - GLiNER)**
100
+ This AI is a precise detective. Once you give it a list of labels, it meticulously scans the text and *pulls out* (extracts) the exact words that match. It's fantastic at finding specific information with high accuracy.
101
+ """
102
+ )
103
+
104
+ gr.Markdown("--- \n## Step 1: Get Label Ideas from the Creative AI")
105
  with gr.Row():
106
+ topic = gr.Textbox(label="Enter a Topic", placeholder="e.g., The Chartist Movement, The Protestant Reformation")
107
+ provider = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Choose Creative AI Model")
108
  with gr.Row():
109
  openai_key = gr.Textbox(label="OpenAI API Key", type="password")
110
  anthropic_key = gr.Textbox(label="Anthropic API Key", type="password")
111
  google_key = gr.Textbox(label="Google API Key", type="password")
 
112
 
113
+ generate_btn = gr.Button("Generate Label Suggestions", variant="primary")
114
+
115
+ gr.Markdown("--- \n## Step 2: Build Your Search & Analyze Text")
116
+ gr.Markdown(
117
+ """
118
+ ### What are Entities or Labels?
119
+ Think of them as special highlighters! They find and color-code specific types of information in your text, like `PERSON`, `DATE`, `LOCATION`, or custom things you define.
120
+ """
121
+ )
122
+
123
+ gr.Markdown("#### 1. Review AI-Suggested Labels")
124
+ gr.Markdown("The AI's suggestions appear below. Uncheck any you don't want.")
125
+
126
+ dynamic_components = []
127
  with gr.Column():
128
  for i in range(MAX_CATEGORIES):
129
+ with gr.Accordion(f"Suggested Label Category {i+1}", visible=False) as acc:
130
  with gr.Row():
131
+ # The CheckboxGroup holds the actual labels (e.g., "Protest", "Petition")
132
+ cg = gr.CheckboxGroup(label="Labels in this category", interactive=True, container=False, scale=4)
133
+ deselect_btn = gr.Button("Deselect All", size="sm", scale=1, min_width=80)
134
+ dynamic_components.append((acc, cg, deselect_btn))
135
+
136
+ gr.Markdown("#### 2. Include Standard Labels (Optional)")
137
  with gr.Group():
138
+ standard_labels_checkbox = gr.CheckboxGroup(choices=STANDARD_LABELS, value=STANDARD_LABELS, label="Standard Entity Labels", info="Common categories like people, places, and dates.")
139
+ with gr.Row():
140
+ select_all_std_btn = gr.Button("Select All", size="sm")
141
+ deselect_all_std_btn = gr.Button("Deselect All", size="sm")
142
+
143
+
144
+ gr.Markdown("#### 3. Add Your Own Custom Labels (Optional)")
145
  with gr.Group():
146
+ custom_labels_textbox = gr.Textbox(label="Enter Custom Labels (comma-separated)", placeholder="e.g., Technology, Weapon, Secret Society...")
147
+
148
+ gr.Markdown("--- \n## Step 3: Analyze Your Document")
149
+ threshold_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold", info="Controls how strict the AI Detective is. Lower to find more matches. Higher for fewer, more precise matches.")
150
+ text_input = gr.Textbox(label="Paste Your Full Text Here for Analysis", lines=10, placeholder="Paste a historical document, an article, or a chapter...")
151
+ analyze_btn = gr.Button("Analyze Text & Find Entities", variant="primary")
152
+
153
+ analysis_status = gr.Markdown(visible=False) # For the "Analyzing..." message
154
+
155
+ gr.Markdown("--- \n## Step 4: Review Your Results")
156
+ gr.Markdown(
157
+ """
158
+ ✨ **Pro Tip: Create Your Own Labels!**
159
+ Did our AI miss something? In the **"Highlighted Text"** view below, simply **click and drag to highlight any piece of text**. A small box will appear, allowing you to name and add your own custom label!
160
+ """
161
+ )
162
163
  with gr.Tabs():
164
  with gr.TabItem("Highlighted Text"):
165
+ highlighted_text_output = gr.HighlightedText(label="Found Entities", interactive=True)
166
  with gr.TabItem("Detailed Results"):
167
+ detailed_results_output = gr.Markdown(label="List of Found Entities by Label")
168
  with gr.TabItem("Debug Info"):
169
  debug_output = gr.Textbox(label="Extraction Log", interactive=False, lines=8)
170
 
171
  # --- Backend Functions ---
173
  def handle_generate(topic, provider, openai_k, anthropic_k, google_k):
174
+ yield {
175
+ generate_btn: gr.update(value="🧠 Generating suggestions...", interactive=False)
176
+ }
177
+
178
  try:
179
+ key_dict = {
180
+ "openai_key": os.environ.get("OPENAI_API_KEY", openai_k),
181
+ "anthropic_key": os.environ.get("ANTHROPIC_API_KEY", anthropic_k),
182
+ "google_key": os.environ.get("GOOGLE_API_KEY", google_k)
183
+ }
184
+
185
  provider_id = MODEL_OPTIONS.get(provider)
186
+ if not topic or not provider or not key_dict.get(f"{provider_id}_key"):
187
+ raise gr.Error("Topic, Provider, and the correct API Key are required.")
188
+
189
  prompt = HIERARCHICAL_PROMPT_TEMPLATE.format(topic=topic)
190
  raw_framework = generate_from_prompt(prompt, provider, key_dict)
191
+
192
+ # This parsing is simplified for the new structure
193
  framework = defaultdict(list)
194
  current_category = None
195
  for line in raw_framework.split('\n'):
196
  line = line.strip()
197
+ if line.startswith("###"):
198
+ current_category = line.replace("###", "").strip()
199
+ elif line.startswith("-") and current_category:
200
+ entities = line.replace("-", "").strip()
201
+ framework[current_category].extend([e.strip() for e in entities.split(',') if e.strip()])
202
+
203
+ if not framework:
204
+ raise gr.Error("AI failed to generate categories. Please try again or rephrase your topic.")
205
+
206
  updates = {}
207
  categories = list(framework.items())
208
  for i in range(MAX_CATEGORIES):
209
+ accordion_comp, checkbox_comp, button_comp = dynamic_components[i]
210
  if i < len(categories):
211
+ category_name, entities = categories[i]
212
+ # The labels are the entities themselves, grouped by the category name
213
  sorted_entities = sorted(list(set(entities)))
214
+ updates[accordion_comp] = gr.update(label=f"Category: {category_name}", visible=True)
215
+ updates[checkbox_comp] = gr.update(choices=sorted_entities, value=sorted_entities, label="Suggested Labels", visible=True)
216
+ updates[button_comp] = gr.update(visible=True)
217
  else:
218
  updates[accordion_comp] = gr.update(visible=False)
219
  updates[checkbox_comp] = gr.update(visible=False)
220
+ updates[button_comp] = gr.update(visible=False)
221
+
222
+ updates[generate_btn] = gr.update(value="Generate Label Suggestions", interactive=True)
223
  yield updates
224
  except Exception as e:
225
+ yield {generate_btn: gr.update(value="Generate Label Suggestions", interactive=True)}
226
  raise gr.Error(str(e))
227
 
228
+ def analyze_text_and_find_entities(text, standard_labels, custom_label_text, threshold, *suggested_labels_from_groups):
229
+ # --- 1. Show Progress to User ---
230
  yield {
231
+ analyze_btn: gr.update(value="🕵️ Analyzing...", interactive=False),
232
+ analysis_status: gr.update(value="Our AI Detective is scanning your text. This may take a moment...", visible=True),
233
+ highlighted_text_output: None,
234
  detailed_results_output: None,
235
+ debug_output: "Starting analysis..."
 
236
  }
 
237
 
238
+ debug_info = []
239
+ if gliner_model is None:
240
+ raise gr.Error("GLiNER model failed to load at startup. Cannot analyze text. Please check logs.")
241
+
242
+ # --- 2. Collect All Labels from UI ---
243
  labels_to_use = set()
244
+ # Add labels from the dynamically generated suggestion groups
245
+ for group in suggested_labels_from_groups:
246
  if group: labels_to_use.update(group)
247
+ # Add labels from the standard list
248
+ if standard_labels: labels_to_use.update(standard_labels)
249
+ # Add labels from the custom textbox
250
  custom = {l.strip() for l in custom_label_text.split(',') if l.strip()}
251
  if custom: labels_to_use.update(custom)
 
 
252
 
253
+ final_labels = sorted(list(labels_to_use))
254
+ debug_info.append(f"🧠 Searching for {len(final_labels)} unique labels.")
255
+ debug_info.append(f"βš™οΈ Confidence Threshold: {threshold}")
256
+
257
  if not text or not final_labels:
258
+ yield {
259
+ analyze_btn: gr.update(value="Analyze Text & Find Entities", interactive=True),
260
+ analysis_status: gr.update(visible=False),
261
+ highlighted_text_output: {"text": text, "entities": []},
262
+ detailed_results_output: "Please provide text and select at least one label to search for.",
263
+ debug_output: "Analysis stopped: No text or no labels provided."
264
+ }
265
  return
266
+
267
+ # --- 3. Run the GLiNER Model (The "Detective") ---
268
  all_entities = []
269
+ # Process text in chunks to handle very long documents
270
+ chunk_size, overlap = 1024, 100
271
+ for i in range(0, len(text), chunk_size - overlap):
272
  chunk = text[i : i + chunk_size]
273
  chunk_entities = gliner_model.predict_entities(chunk, final_labels, threshold=threshold)
274
  for ent in chunk_entities:
275
+ ent['start'] += i
276
+ ent['end'] += i
277
  all_entities.append(ent)
278
 
279
+ # Deduplicate entities that might span across chunk overlaps
280
  unique_entities = [dict(t) for t in {tuple(d.items()) for d in all_entities}]
281
+ debug_info.append(f"πŸ“Š Found {len(unique_entities)} raw entity mentions.")
282
+
283
+ # --- 4. Prepare Highlighted Text Output ---
284
+ highlighted_output_data = {
285
+ "text": text,
286
+ "entities": [{"start": ent["start"], "end": ent["end"], "label": ent["label"]} for ent in unique_entities]
287
+ }
288
+
289
+ # --- 5. Prepare Detailed Table-Based Results ---
290
+ aggregated_matches = defaultdict(lambda: {'count': 0, 'scores': [], 'original_casing': ''})
291
 
292
+ for ent in unique_entities:
293
+ match_text = text[ent['start']:ent['end']]
294
+ # Use a key of (label, lowercase_text) to group similar items
295
+ key = (ent['label'], match_text.lower())
296
+
297
+ aggregated_matches[key]['count'] += 1
298
+ aggregated_matches[key]['scores'].append(ent['score'])
299
+ # Store the first-seen casing of the text
300
+ if not aggregated_matches[key]['original_casing']:
301
+ aggregated_matches[key]['original_casing'] = match_text
302
 
303
+ # Group aggregated results by label for final display
304
+ results_by_label = defaultdict(list)
305
+ for (label, _), data in aggregated_matches.items():
306
+ avg_score = np.mean(data['scores'])
307
+ results_by_label[label].append({
308
+ 'text': data['original_casing'],
309
+ 'count': data['count'],
310
+ 'avg_score': avg_score
311
+ })
312
 
313
+ # --- 6. Build the Markdown String for the Detailed Table ---
314
+ markdown_string = ""
315
+ for label, items in sorted(results_by_label.items()):
316
+ markdown_string += f"### {label}\n"
317
+ markdown_string += "| Text Found | Instances Found | Avg. Confidence Score* |\n"
318
+ markdown_string += "|------------|-----------------|--------------------------|\n"
319
+
320
+ # Sort items by count (most frequent first)
321
+ for item in sorted(items, key=lambda x: x['count'], reverse=True):
322
+ markdown_string += f"| {item['text']} | {item['count']} | {item['avg_score']:.2f} |\n"
323
+ markdown_string += "\n"
324
+
325
+ if not markdown_string:
326
+ markdown_string = "No entities found. Try lowering the confidence threshold or changing your labels."
327
+ else:
328
+ markdown_string += "\n---\n<small><i>*<b>Confidence Score:</b> How sure the AI Detective (GLiNER) is that it found the correct label (1.00 = 100% certain). The score shown is the average across all instances of that text.</i></small>"
329
+
330
+ debug_info.append("βœ… Analysis complete.")
331
+
332
+ # --- 7. Yield Final Results to UI ---
333
  yield {
334
+ analyze_btn: gr.update(value="Analyze Text & Find Entities", interactive=True),
335
+ analysis_status: gr.update(visible=False),
336
+ highlighted_text_output: highlighted_output_data,
337
+ detailed_results_output: markdown_string,
338
+ debug_output: "\n".join(debug_info)
 
 
339
  }
340
 
341
+ # --- Wire up UI events ---
342
+ generate_btn.click(
343
+ fn=handle_generate,
344
+ inputs=[topic, provider, openai_key, anthropic_key, google_key],
345
+ outputs=[generate_btn] + [comp for pair in dynamic_components for comp in pair]
346
  )
347
+
348
+ # Functions for Select/Deselect All buttons
349
+ def deselect_all():
350
+ return gr.update(value=[])
351
+ def select_all(choices):
352
+ return gr.update(value=choices)
353
 
354
+ deselect_all_std_btn.click(fn=deselect_all, inputs=None, outputs=[standard_labels_checkbox])
355
+ deselect_all_std_btn.click(fn=deselect_all, inputs=None, outputs=[standard_labels_checkbox])
356
 
357
+ for _, cg, btn in dynamic_components:
358
+ btn.click(fn=deselect_all, inputs=None, outputs=[cg])
359
+
360
+ analyze_btn.click(
361
+ fn=analyze_text_and_find_entities,
362
+ inputs=[text_input, standard_labels_checkbox, custom_labels_textbox, threshold_slider] + [cg for acc, cg, btn in dynamic_components],
363
+ outputs=[analyze_btn, analysis_status, highlighted_text_output, detailed_results_output, debug_output]
364
  )
365
 
366
  demo.launch(share=True, debug=True)
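For reference, the aggregation behind the new "Detailed Results" tab (steps 5 and 6 of analyze_text_and_find_entities) can be reproduced standalone. A minimal sketch, assuming entity dicts carry the start/end/label/score fields that predict_entities returns; the helper name entities_to_markdown and the sample entities are invented for illustration:

from collections import defaultdict
import numpy as np

def entities_to_markdown(text, entities):
    """Group mentions by (label, lowercased text), count them, average scores, render Markdown tables."""
    agg = defaultdict(lambda: {"count": 0, "scores": [], "casing": ""})
    for ent in entities:
        mention = text[ent["start"]:ent["end"]]
        key = (ent["label"], mention.lower())
        agg[key]["count"] += 1
        agg[key]["scores"].append(ent["score"])
        # Keep the first-seen casing, as the app does.
        agg[key]["casing"] = agg[key]["casing"] or mention
    by_label = defaultdict(list)
    for (label, _), d in agg.items():
        by_label[label].append((d["casing"], d["count"], float(np.mean(d["scores"]))))
    md = ""
    for label, rows in sorted(by_label.items()):
        md += f"### {label}\n| Text Found | Instances Found | Avg. Confidence Score |\n|---|---|---|\n"
        for mention, count, score in sorted(rows, key=lambda r: r[1], reverse=True):
            md += f"| {mention} | {count} | {score:.2f} |\n"
        md += "\n"
    return md or "No entities found."

# Example with invented data:
# text = "Martin Luther met Martin Luther in Wittenberg."
# ents = [{"start": 0, "end": 13, "label": "PERSON", "score": 0.92},
#         {"start": 18, "end": 31, "label": "PERSON", "score": 0.88},
#         {"start": 35, "end": 45, "label": "LOCATION", "score": 0.97}]
# print(entities_to_markdown(text, ents))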