brandonmusic committed on
Commit
efaf209
·
verified ·
1 Parent(s): b88cb65

Update App.py

Files changed (1)
  1. App.py +490 -405
App.py CHANGED
@@ -1,31 +1,18 @@
1
- # script.py
2
- # This is the updated main script. Copy-paste this over your existing script.py.
3
- # Changes:
4
- # - Verified consistent use of prompt.lower() for prompt_lower (it was already correct, so no functional change).
5
- # - Split route_model responsibilities: Moved retrieval functions to retrieval.py, prompt building to prompt_builder.py, post-processing to post_processing.py.
6
- # - Imports: Add 'from retrieval import *', 'from prompt_builder import *', 'from post_processing import *' after your existing imports.
7
- # - Synchronous loads: hf_hub_download caches downloads, so repeat calls are fast. To avoid first-time blocking, a background preloader thread pre-downloads a limited number of clusters (e.g., the first 10) at startup. If you have many clusters, use a separate script (see below) to pre-download them all offline.
8
- # - No need for a separate script unless you want to pre-download ALL clusters (which could be storage-intensive). If yes, see the optional pre_download_clusters.py below.
9
- # - In route_model, now calls the split functions.
10
- # - Passed necessary globals (e.g., cap_dataset, cap_id_to_index) to semantic_search.
11
- # - Ensured municipal_embeddings is loaded (assume it's global).
12
- # - Added threading for preloading some clusters at startup.
13
-
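The comments above reference an optional pre_download_clusters.py for pre-fetching every cluster offline. A minimal sketch of what such a script could look like, assuming load_cluster_vectors() from retrieval.py performs the cached hf_hub_download exactly as the in-process preloader thread below does for clusters 0-9 (the total cluster count is a placeholder):

# pre_download_clusters.py -- hypothetical offline pre-download helper (sketch only)
from retrieval import load_cluster_vectors  # assumed to download and cache one cluster's vectors

TOTAL_CLUSTERS = 100  # assumption: set this to the real number of clusters in the embeddings repo

if __name__ == "__main__":
    for cluster_id in range(TOTAL_CLUSTERS):
        try:
            load_cluster_vectors(cluster_id, model="gte-large")
            print(f"Cached cluster {cluster_id}")
        except Exception as e:
            print(f"Failed to cache cluster {cluster_id}: {e}")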
14
- import gradio as gr
15
- from openai import OpenAI
16
  import requests
17
  import os
18
  import logging
19
  from datetime import datetime
20
  import pdfplumber
21
- from googleapiclient.discovery import build
 
 
22
  import re
23
- from datasets import load_dataset, Dataset, load_from_disk
24
  from sentence_transformers import SentenceTransformer
25
  import torch
26
  import numpy as np
27
  import shutil
28
- import pyarrow.parquet as pq
29
  from huggingface_hub import hf_hub_download
30
  import pickle
31
  import faiss
@@ -33,340 +20,221 @@ import threading
33
  import subprocess
34
  from task_processing import process_task_response
35
  from gpt_helpers import ask_gpt41_mini
36
 
37
- # New imports for split modules
38
- from retrieval import *
39
- from prompt_builder import *
40
- from post_processing import *
41
 
42
  os.environ["HF_HOME"] = "/data/.huggingface"
43
- # Add or update this section in script.py
44
- # Ensure this is placed after imports but before any dataset loading or function definitions
45
 
46
- from huggingface_hub import login
47
 
48
- # Load HF token for SaulLM endpoint and gated repos
 
49
  hf_token = os.environ.get("HF_TOKEN", "")
50
- if not hf_token:
51
- logger.warning("HF_TOKEN not set; SaulLM endpoint may require authentication and gated repos may not be accessible.")
52
-
53
- # Authenticate for gated Hugging Face repos (e.g., for centroids download)
54
  if hf_token:
55
  login(hf_token)
56
- logger.info("Authenticated with Hugging Face token for gated repos.")
57
  else:
58
- logger.warning("No HF_TOKEN; may fail to access gated repos like Caselaw_Access_Project_embeddings.")
59
 
60
  # Check environment variables
61
- try:
62
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "Missing")
63
- GOOGLE_SEARCH_API = os.environ.get("GOOGLE_SEARCH_API", "Missing") # This is now treated as CSE ID (cx)
64
- GOOGLE_CUSTOM_SEARCH_API_KEY = os.environ.get("GOOGLE_CUSTOM_SEARCH_API_KEY", "Missing") # New: API key (developerKey)
65
- if OPENAI_API_KEY == "Missing" or GOOGLE_CUSTOM_SEARCH_API_KEY == "Missing" or GOOGLE_SEARCH_API == "Missing":
66
- raise KeyError("API keys not set")
67
- logger.info(f"OpenAI API Key starts with: {OPENAI_API_KEY[:10]}...")
68
- logger.info("API keys loaded successfully")
69
- except KeyError as e:
70
- logger.error(f"Missing environment variable: {str(e)}")
71
- raise EnvironmentError(f"Required secrets OPENAI_API_KEY, GOOGLE_CUSTOM_SEARCH_API_KEY, and GOOGLE_SEARCH_API must be set in Hugging Face Space Secrets")
72
- def download_file_if_not_exists(url, save_path):
73
- if not os.path.exists(save_path):
74
- try:
75
- response = requests.get(url)
76
- response.raise_for_status() # Raise error if not 200
77
- with open(save_path, 'wb') as f:
78
- f.write(response.content)
79
- logger.info(f"Downloaded and saved file to {save_path}")
80
- except Exception as e:
81
- logger.error(f"Failed to download from {url}: {str(e)}")
82
-
83
- # Download the centroids file if not present
84
- centroid_url = "https://huggingface.co/datasets/laion/Caselaw_Access_Project_embeddings/blob/main/TeraflopAI___Caselaw_Access_Project_centroids.parquet"
85
- centroid_path = "TeraflopAI___Caselaw_Access_Project_centroids.parquet"
86
- download_file_if_not_exists(centroid_url, centroid_path)
87
- # Load HF token for SaulLM endpoint
88
- hf_token = os.environ.get("HF_TOKEN", "")
89
- if not hf_token:
90
- logger.warning("HF_TOKEN not set; SaulLM endpoint may require authentication")
91
-
92
- import requests
93
-
94
-
95
- # Initialize OpenAI client
96
- openai_client = OpenAI(api_key=OPENAI_API_KEY)
97
-
98
- # SaulLM endpoint
99
- SAUL_ENDPOINT = "https://l4tuv4j9bu616t5x.us-east-1.aws.endpoints.huggingface.cloud"
100
-
101
- # Persistent storage path for dataset
102
- LOCAL_PATH = "/data/cap_dataset"
103
- dataset_info_path = os.path.join(LOCAL_PATH, 'dataset_info.json')
104
- if os.path.exists(dataset_info_path):
105
- cap_dataset = load_from_disk(LOCAL_PATH)
106
  else:
107
- try:
108
- cap_dataset = load_dataset("TeraflopAI/Caselaw-Access-Project", split="train")
109
- cap_dataset.save_to_disk(LOCAL_PATH)
110
- except Exception as e:
111
- logger.error(f"Dataset download/save failed: {str(e)}")
112
- if os.path.exists(LOCAL_PATH):
113
- shutil.rmtree(LOCAL_PATH) # Clean up partial save
114
- raise
115
-
116
- # Persistent storage paths for municipal
117
- MUNICIPAL_EMBEDDINGS_PATH = "/data/municipal_embeddings"
118
- MUNICIPAL_HTML_PATH = "/data/municipal_html"
119
- MUNICIPAL_CITATION_PATH = "/data/municipal_citation"
120
- MUNICIPAL_FAISS_INDEX_PATH = "/data/municipal_faiss.index"
121
- MUNICIPAL_CID_TO_HTML_PATH = "/data/cid_to_html.pkl"
122
- MUNICIPAL_CID_TO_CITATION_PATH = "/data/cid_to_citation.pkl"
123
-
124
- # Only trigger once to avoid re-downloading
125
- def run_municipal_embedding_script_once():
126
- marker_path = "/data/municipal_embeddings_done.txt"
127
- if not os.path.exists(marker_path):
128
- try:
129
- subprocess.run(["python", "prepare_municipal_embeddings.py"], check=True)
130
- with open(marker_path, "w") as f:
131
- f.write("done")
132
- logger.info("✅ Municipal embedding preparation complete.")
133
- except Exception as e:
134
- logger.error(f"❌ Error running prepare_municipal_embeddings.py: {e}")
135
-
136
- # Only trigger once to avoid re-building
137
- def run_build_municipal_faiss_once():
138
- marker_path = "/data/municipal_faiss_done.txt"
139
- if not os.path.exists(marker_path):
140
- try:
141
- subprocess.run(["python", "build_municipal_faiss.py"], check=True)
142
- with open(marker_path, "w") as f:
143
- f.write("done")
144
- logger.info("✅ Municipal FAISS index build complete.")
145
- except Exception as e:
146
- logger.error(f"❌ Error running build_municipal_faiss.py: {e}")
147
-
148
- # Launch in background AFTER app boots
149
- # Run preparation scripts synchronously if not done
150
- run_municipal_embedding_script_once()
151
- run_build_municipal_faiss_once()
152
-
153
- # Load municipal embeddings dataset
154
- if os.path.exists(MUNICIPAL_EMBEDDINGS_PATH):
155
- municipal_embeddings = load_from_disk(MUNICIPAL_EMBEDDINGS_PATH)
156
- else:
157
- logger.error("Municipal embeddings not found. Ensure prepare_municipal_embeddings.py ran successfully.")
158
- municipal_embeddings = None # Fallback or error handling
159
-
160
- # Load municipal html dataset
161
- if os.path.exists(MUNICIPAL_HTML_PATH):
162
- municipal_html = load_from_disk(MUNICIPAL_HTML_PATH)
163
- else:
164
- logger.error("Municipal html not found. Ensure prepare_municipal_embeddings.py ran successfully.")
165
- municipal_html = None
166
-
167
- # Load municipal citation dataset
168
- if os.path.exists(MUNICIPAL_CITATION_PATH):
169
- municipal_citation = load_from_disk(MUNICIPAL_CITATION_PATH)
170
- else:
171
- logger.error("Municipal citation not found. Ensure prepare_municipal_embeddings.py ran successfully.")
172
- municipal_citation = None
173
-
174
- # Precompute CID to index mapping for CAP dataset
175
- cap_id_to_index = {doc['cid']: i for i, doc in enumerate(cap_dataset) if 'cid' in doc}
176
-
177
- # Preload some clusters in background (e.g., clusters 0-9)
178
- def preload_clusters():
179
- for cluster_id in range(10): # Adjust range as needed
180
- try:
181
- load_cluster_vectors(cluster_id, model="gte-large")
182
- logger.info(f"Preloaded cluster {cluster_id}")
183
- except Exception as e:
184
- logger.error(f"Preload failed for cluster {cluster_id}: {e}")
185
 
186
- threading.Thread(target=preload_clusters).start()
 
187
 
188
  # State dictionary for jurisdiction
189
  STATES = {
190
- "AL": "Alabama",
191
- "AK": "Alaska",
192
- "AZ": "Arizona",
193
- "AR": "Arkansas",
194
- "CA": "California",
195
- "CO": "Colorado",
196
- "CT": "Connecticut",
197
- "DE": "Delaware",
198
- "FL": "Florida",
199
- "GA": "Georgia",
200
- "HI": "Hawaii",
201
- "ID": "Idaho",
202
- "IL": "Illinois",
203
- "IN": "Indiana",
204
- "IA": "Iowa",
205
- "KS": "Kansas",
206
- "KY": "Kentucky",
207
- "LA": "Louisiana",
208
- "ME": "Maine",
209
- "MD": "Maryland",
210
- "MA": "Massachusetts",
211
- "MI": "Michigan",
212
- "MN": "Minnesota",
213
- "MS": "Mississippi",
214
- "MO": "Missouri",
215
- "MT": "Montana",
216
- "NE": "Nebraska",
217
- "NV": "Nevada",
218
- "NH": "New Hampshire",
219
- "NJ": "New Jersey",
220
- "NM": "New Mexico",
221
- "NY": "New York",
222
- "NC": "North Carolina",
223
- "ND": "North Dakota",
224
- "OH": "Ohio",
225
- "OK": "Oklahoma",
226
- "OR": "Oregon",
227
- "PA": "Pennsylvania",
228
- "RI": "Rhode Island",
229
- "SC": "South Carolina",
230
- "SD": "South Dakota",
231
- "TN": "Tennessee",
232
- "TX": "Texas",
233
- "UT": "Utah",
234
- "VT": "Vermont",
235
- "VA": "Virginia",
236
- "WA": "Washington",
237
- "WV": "West Virginia",
238
- "WI": "Wisconsin",
239
- "WY": "Wyoming",
240
- "Federal": "Federal",
241
- "All States": "All States",
242
- "Other": "Other States"
243
  }
 
244
 
245
- def route_model(prompt, task_type, files=None, search_web=False, jurisdiction="KY"):
246
- logger.info(f"Routing prompt: {prompt}, Task: {task_type}, Web Search: {search_web}, Jurisdiction: {jurisdiction}")
247
-
248
- rag_context = ""
249
- if task_type in ["case_law", "irac", "statute"]:
250
- cap_results = semantic_search(prompt, top_k=5)
251
- municipal_results = municipal_search(prompt, top_k=5)
252
- combined_results = cap_results + municipal_results
253
-
254
- # Filter by jurisdiction if specified (e.g., "KY" for Kentucky)
255
- if jurisdiction and jurisdiction != "All States":
256
- state_name = STATES.get(jurisdiction, "")
257
- state_code = jurisdiction # e.g., "KY"
258
- combined_results = [r for r in combined_results if state_code in r['citation'] or state_name in r['citation'] or state_code in r['name'] or state_name in r['name']]
259
-
260
- if combined_results:
261
- rag_context = "Retrieved legal authorities (case law and statutes):\n" + "\n".join([f"{i+1}. [{auth.get('source', 'Unknown')}] {auth['name']}, {auth['citation']}: \"{auth['snippet']}\"" for i, auth in enumerate(combined_results)])
262
-
263
- prompt = f"User prompt: {prompt}\n\n{rag_context}"
264
-
265
- saul_response = ask_saul(prompt, task_type, jurisdiction)
266
-
267
- # Task-specific processing (existing code)
268
- saul_response = process_task_response(task_type, saul_response, prompt, jurisdiction)
269
-
270
- if search_web:
271
- web_data = google_search(prompt)
272
- saul_response = f"Google Search results: {web_data}\n{saul_response}"
273
-
274
- editor_prompt = build_editor_prompt(prompt, task_type, jurisdiction, saul_response, rag_context)
275
-
276
- final_response = ask_gpt4o(editor_prompt)
277
-
278
- final_response = ground_statutes(final_response, jurisdiction)
279
-
280
- return final_response
281
-
282
- def ask_saul(messages, task_type, jurisdiction):
283
  try:
284
- headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
285
  payload = {
286
  "messages": messages,
287
- "parameters": {
288
- "max_length": 512,
289
- "temperature": 0.3
 
 
 
290
  }
291
  }
292
- logger.info(f"SaulLM payload: messages length={len(messages)}, max_length={payload['parameters']['max_length']}")
293
- response = requests.post(SAUL_ENDPOINT, headers=headers, json=payload)
 
294
  response.raise_for_status()
295
- result = response.json()
296
- if isinstance(result, dict) and "choices" in result:
297
- return result["choices"][0].get("message", {}).get("content", "[No response from SaulLM]")
298
- elif isinstance(result, list) and result:
299
- return result[0].get("generated_text", "[No response from SaulLM]")
300
  else:
301
- return result.get("generated_text", "[No response from SaulLM]")
302
-
303
  except Exception as e:
304
- logger.error(f"SaulLM error: {str(e)}")
305
- return "SaulLM service unavailable. Using fallback response."
306
-
307
308
  try:
309
- response = openai_client.chat.completions.create(
310
- model="gpt-4", # Placeholder, replace with fine-tuned model
311
- messages=[
312
- {"role": "system", "content": (
313
- f"You are a legal assistant drafting documents for {jurisdiction} jurisdiction. "
314
- "Always quote directly from retrieved case law. Use full case names and citations (e.g., 'Smith v. Jones, 123 S.W.3d 456 (Ky. 2005)'). "
315
- "Prioritize high quote density and include facts from those cases when applying them. Use IRAC structure. Do not paraphrase available holdings."
316
- )},
317
- {"role": "user", "content": prompt}
318
- ],
319
- temperature=0.3,
320
- max_tokens=8192
321
- )
322
- return response.choices[0].message.content
323
- except Exception as e:
324
- logger.error(f"GPT-4.1 Mini error: {str(e)}")
325
- return f"[GPT-4.1 Mini Error] {str(e)}"
326
-
327
- def ask_gpt4o(prompt):
328
- try:
329
- response = openai_client.chat.completions.create(
330
- model="gpt-4o",
331
- messages=[
332
- {"role": "system", "content": (
333
- "You are the final editor for a legal research assistant. Polish and organize the output into clear IRAC format. "
334
- "Ensure high quote density from retrieved authorities and include relevant facts from the cited cases. "
335
- "Maintain accurate citations. Do not paraphrase legal holdings when direct quotes are available."
336
- )},
337
- {"role": "user", "content": prompt}
338
- ],
339
- temperature=0.3,
340
- max_tokens=16384
341
- )
342
- return response.choices[0].message.content
343
- except Exception as e:
344
- logger.error(f"GPT-4o error: {str(e)}")
345
- return f"[GPT-4o Error] {str(e)}"
346
-
347
- def extract_text_from_pdf(file_path):
348
- try:
349
- with pdfplumber.open(file_path) as pdf:
350
- text = ""
351
- for page in pdf.pages:
352
- text += page.extract_text() or ""
353
- logger.info(f"Extracted text length: {len(text)}")
354
- return text
355
  except Exception as e:
356
- logger.error(f"PDF extraction error: {str(e)}")
357
  return ""
358
 
359
  def classify_prompt(prompt):
360
  prompt_lower = prompt.lower()
361
  if "summarize" in prompt_lower:
362
  return "document_analysis" # Treat summarize as analysis for routing
363
- if any(k in prompt_lower for k in ["irac", "issue", "rule", "analysis", "conclusion"]):
364
  return "irac"
365
  elif any(k in prompt_lower for k in ["case", "precedent", "law"]):
366
  return "case_law"
367
  elif any(k in prompt_lower for k in ["statute", "krs"]):
368
  return "statute"
369
- elif any(k in prompt_lower for k in ["draft", "write", "generate", "petition", "letter", "contract"]):
370
  return "document_creation"
371
  elif any(k in prompt_lower for k in ["review", "summarize", "clause", "red flags"]):
372
  return "document_analysis"
@@ -392,104 +260,321 @@ def classify_prompt(prompt):
392
  return "legal_strategy"
393
  return "general_qa"
394
 
395
- def chat_interface(prompt, files, history, search_web=False, jurisdiction="KY"):
396
- timestamp = datetime.now().strftime("%I:%M %p %m/%d/%Y")
397
- task_type = classify_prompt(prompt)
398
- if files:
399
- file_text = extract_text_from_pdf(files[0]) if files else ""
400
- if "summarize" in prompt.lower():
401
- task_type = "document_analysis"
402
- response = summarize_document(files)
403
- elif "analyze" in prompt.lower():
404
- task_type = "document_analysis"
405
- response = analyze_document(files)
406
- elif "check" in prompt.lower() or "issues" in prompt.lower():
407
- task_type = "document_analysis"
408
- response = check_issues(files)
409
- elif "generate" in prompt.lower() or "draft" in prompt.lower():
410
- task_type = "document_creation"
411
- response = ask_gpt41_mini(prompt + "\nAttached file content: " + file_text, jurisdiction)
412
  else:
413
- prompt += "\nAttached file content: " + file_text[:10000]
414
- response = route_model(prompt, task_type, files, search_web, jurisdiction)
415
  else:
416
- response = route_model(prompt, task_type, files, search_web, jurisdiction)
417
- history.append((f"{prompt} <span style='color: #ECF0F1; font-size: 16px;'>[{timestamp}]</span>",
418
- f"{response} <span style='color: #ECF0F1; font-size: 16px;'>[{timestamp}]</span>"))
419
- return history, history
420
 
421
- def new_chat():
422
- return [], []
423
 
424
  def summarize_document(files):
425
- if files and isinstance(files, list) and files:
426
- file = files[0]
427
- text = extract_text_from_pdf(file)
428
- if text:
429
- summary = ask_gpt4o(f"Summarize the following document: {text[:10000]}") # Limit to avoid token limits
430
- return f"Summary: {summary}"
431
- return "No text extracted from PDF."
432
- return "Please upload a file to summarize."
 
434
  def analyze_document(files):
435
- if files:
436
- text = extract_text_from_pdf(files[0])
437
- if text:
438
- analysis = ask_gpt4o(f"Analyze the following document for legal issues, risks, or key clauses: {text[:10000]}")
439
- return f"Analysis: {analysis}"
440
- return "No text extracted from PDF."
441
- return "No file uploaded for analysis."
442
 
443
  def check_issues(files):
444
- if files:
445
- text = extract_text_from_pdf(files[0])
446
- if text:
447
- issues = ask_gpt4o(f"Check for red flags, unusual clauses, or potential issues in this legal document: {text[:10000]}")
448
- return f"Issues: {issues}"
449
- return "No text extracted from PDF."
450
- return "No file uploaded to check."
451
-
452
- def save_conversation(history):
453
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
454
- content = "\n".join([f"User: {msg[0]}\nBot: {msg[1]}\n" for msg in history])
455
- with open(f"conversation_{timestamp}.txt", "w") as f:
456
- f.write(content)
457
- return f"conversation_{timestamp}.txt"
458
-
459
- css = """
460
- # ... (your CSS remains the same, omitted for brevity)
461
- """
462
-
463
- theme = gr.themes.Base(
464
- primary_hue="gray",
465
- secondary_hue="gray",
466
- neutral_hue="gray",
467
- ).set(
468
- body_text_color="#000000",
469
- background_fill_primary="#6D8299" # Slate blue background
470
- )
471
-
472
- with gr.Blocks(css=css, theme=theme, title="VerdictAI - Legal Assistant") as app:
473
- jurisdiction = gr.State("KY")
474
- chatbot = gr.Chatbot(elem_id="chat-container", label="Chat")
475
- msg = gr.Textbox(
476
- placeholder="Ask any legal question, request a draft document, upload a contract for analysis, or search for statutes and cases.\nExamples:\n‘Write a Kentucky will for a single parent with two children.’\n‘Summarize this operating agreement and flag any unusual clauses.’\n‘Find cases on constructive trust involving fraud.’\n‘What does KRS 411.182 mean for comparative fault?’\n‘IRAC analysis: A customer slips on an icy sidewalk outside a store.’",
477
- elem_id="user-input"
478
- )
479
- with gr.Row():
480
- with gr.Column(scale=2, elem_classes=["main-content"]):
481
- chatbot
482
- with gr.Row(elem_id="chat-input"):
483
- msg
484
- file_upload = gr.File(file_count="multiple", file_types=[".pdf", "image", "text"], elem_id="file-upload-main", label="📎 Upload")
485
- btn = gr.Button("Send", elem_id="send-btn")
486
- google_search_btn = gr.Button("Google Search", elem_id="google-search-btn")
487
- save_btn = gr.Button("Save Chat", elem_id="save-btn")
488
- action_dropdown = gr.Dropdown(["Summarize", "Analyze", "Check Issues"], label="File Action")
489
-
490
- btn.click(fn=chat_interface, inputs=[msg, file_upload, chatbot, gr.State(False), jurisdiction], outputs=[chatbot, chatbot])
491
- google_search_btn.click(fn=chat_interface, inputs=[msg, file_upload, chatbot, gr.State(True), jurisdiction], outputs=[chatbot, chatbot])
492
- save_btn.click(save_conversation, inputs=[chatbot], outputs=gr.File())
493
-
494
- logger.info("Gradio app initialized successfully")
495
- app.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
1
+ # app.py
2
  import requests
3
  import os
4
  import logging
5
  from datetime import datetime
6
  import pdfplumber
7
+ from docx import Document
8
+ from docx.shared import Pt, Inches
9
+ from docx.enum.text import WD_ALIGN_PARAGRAPH
10
  import re
11
+ from datasets import load_dataset, load_from_disk
12
  from sentence_transformers import SentenceTransformer
13
  import torch
14
  import numpy as np
15
  import shutil
 
16
  from huggingface_hub import hf_hub_download
17
  import pickle
18
  import faiss
 
20
  import subprocess
21
  from task_processing import process_task_response
22
  from gpt_helpers import ask_gpt41_mini
23
+ from retrieval import retrieve_context
24
+ from prompt_builder import build_grok_prompt, build_editor_prompt
25
+ from flask import Flask, request, jsonify, send_from_directory, send_file, Response, stream_with_context
26
+ from werkzeug.utils import secure_filename
27
+ from rank_bm25 import BM25Okapi
28
+ from requests.adapters import HTTPAdapter
29
+ from urllib3.util.retry import Retry
30
+ import json  # used for request payload parsing and safe SSE chunk escaping
31
 
32
+ app = Flask(__name__) # Renamed from app_flask to app for HF Spaces compatibility
 
 
 
33
 
34
  os.environ["HF_HOME"] = "/data/.huggingface"
 
 
35
 
36
+ # Logging setup
37
+ logger = logging.getLogger("app")
38
+ logging.basicConfig(level=logging.INFO)
39
+ logger.info("✅ Logging initialized. Starting app setup.")
40
+ print("App setup starting...") # Fallback print for early debug
41
 
42
+ # Hugging Face authentication
43
+ from huggingface_hub import login
44
  hf_token = os.environ.get("HF_TOKEN", "")
45
  if hf_token:
46
  login(hf_token)
47
+ logger.info("Authenticated with Hugging Face token for gated repos.")
48
  else:
49
+ logger.warning("HF_TOKEN not set; gated repos may not be accessible.")
50
 
51
  # Check environment variables
52
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "Missing")
53
+ GOOGLE_SEARCH_API = os.environ.get("GOOGLE_SEARCH_API", "Missing") # CSE ID
54
+ GOOGLE_CUSTOM_SEARCH_API_KEY = os.environ.get("GOOGLE_CUSTOM_SEARCH_API_KEY", "Missing") # API key
55
+ COURT_LISTENER_API_KEY = os.environ.get("Court_Listener_API", "Missing") # Updated to match HF secret name
56
+ if OPENAI_API_KEY == "Missing":
57
+ logger.warning("OPENAI_API_KEY not set; OpenAI features will fail.")
58
+ if GOOGLE_CUSTOM_SEARCH_API_KEY == "Missing" or GOOGLE_SEARCH_API == "Missing":
59
+ logger.warning("Google Search keys not set; search features will fail.")
60
+ if COURT_LISTENER_API_KEY == "Missing":
61
+ logger.warning("Court_Listener_API not set; CourtListener features will fail.")
62
+ logger.info("✅ API keys checked (with warnings if missing).")
63
+
64
+ # Initialize OpenAI client (only if key present)
65
+ openai_client = None
66
+ if OPENAI_API_KEY != "Missing":
67
+ from openai import OpenAI
68
+ openai_client = OpenAI(api_key=OPENAI_API_KEY)
69
+ logger.info("✅ OpenAI client initialized.")
70
  else:
71
+ logger.warning("Skipping OpenAI client init due to missing key.")
72
+
73
+ # Grok API setup
74
+ GROK_API_URL = "https://api.x.ai/v1/chat/completions"
75
+ GROK_API_TOKEN = os.environ.get("GROK_API_TOKEN", "")  # read the xAI key from an environment secret (name assumed); never hardcode API keys in source
76
+ logger.info("✅ Grok API endpoint and token set.")
77
+
78
+ # Global session for retries
79
+ session = requests.Session()
80
+ retries = Retry(total=3, backoff_factor=1, status_forcelist=[422, 503, 504])
81
+ session.mount('https://', HTTPAdapter(max_retries=retries))
82
+
83
+ # Lazy-load CAP dataset to avoid startup issues
84
+ def get_cap_dataset():
85
+ if not hasattr(get_cap_dataset, 'dataset') or get_cap_dataset.dataset is None:
86
+ from datasets import load_from_disk # Lazy import
87
+ LOCAL_PATH = "/data/cap_dataset"
88
+ if os.path.exists(os.path.join(LOCAL_PATH, 'dataset_info.json')):
89
+ try:
90
+ get_cap_dataset.dataset = load_from_disk(LOCAL_PATH)
91
+ logger.info("✅ Lazy-loaded CAP dataset from /data/cap_dataset.")
92
+ except Exception as e:
93
+ logger.error(f"Failed to load CAP dataset: {str(e)}")
94
+ get_cap_dataset.dataset = None
95
+ else:
96
+ logger.error("CAP dataset not found at /data/cap_dataset. Ensure it’s preloaded.")
97
+ get_cap_dataset.dataset = None
98
+ return get_cap_dataset.dataset
99
+
100
+ get_cap_dataset.dataset = None
101
+ logger.info("✅ CAP dataset lazy-loader defined.")
102
+
103
+ # Lazy-compute CID to index mapping for CAP dataset
104
+ def get_cap_id_to_index():
105
+ if not hasattr(get_cap_id_to_index, 'index') or get_cap_id_to_index.index is None:
106
107
+ cap_dataset = get_cap_dataset()
108
+ if cap_dataset is not None:
109
+ get_cap_id_to_index.index = {doc['cid']: i for i, doc in enumerate(cap_dataset) if 'cid' in doc}
110
+ logger.info("✅ Precomputed CAP CID to index mapping.")
111
+ else:
112
+ get_cap_id_to_index.index = {}
113
+ logger.error("CAP dataset not available for index mapping.")
114
+ return get_cap_id_to_index.index
115
 
116
+ get_cap_id_to_index.index = None
117
+ logger.info("✅ CAP ID-to-index lazy-loader defined.")
118
 
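Both loaders are memoized on first call; a hypothetical lookup sketch (not part of the commit) showing how a CAP row could be fetched by its cid:

# Hypothetical usage of the lazy loaders above (sketch only)
some_cid = "12345"                     # placeholder cid
ds = get_cap_dataset()                 # loads from /data/cap_dataset on first call
cid_index = get_cap_id_to_index()      # builds the cid -> row-index map on first call
row = ds[cid_index[some_cid]] if ds is not None and some_cid in cid_index else None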
119
  # State dictionary for jurisdiction
120
  STATES = {
121
+ "AL": "Alabama", "AK": "Alaska", "AZ": "Arizona", "AR": "Arkansas", "CA": "California",
122
+ "CO": "Colorado", "CT": "Connecticut", "DE": "Delaware", "FL": "Florida", "GA": "Georgia",
123
+ "HI": "Hawaii", "ID": "Idaho", "IL": "Illinois", "IN": "Indiana", "IA": "Iowa",
124
+ "KS": "Kansas", "KY": "Kentucky", "LA": "Louisiana", "ME": "Maine", "MD": "Maryland",
125
+ "MA": "Massachusetts", "MI": "Michigan", "MN": "Minnesota", "MS": "Mississippi", "MO": "Missouri",
126
+ "MT": "Montana", "NE": "Nebraska", "NV": "Nevada", "NH": "New Hampshire", "NJ": "New Jersey",
127
+ "NM": "New Mexico", "NY": "New York", "NC": "North Carolina", "ND": "North Dakota", "OH": "Ohio",
128
+ "OK": "Oklahoma", "OR": "Oregon", "PA": "Pennsylvania", "RI": "Rhode Island", "SC": "South Carolina",
129
+ "SD": "South Dakota", "TN": "Tennessee", "TX": "Texas", "UT": "Utah", "VT": "Vermont",
130
+ "VA": "Virginia", "WA": "Washington", "WV": "West Virginia", "WI": "Wisconsin", "WY": "Wyoming",
131
+ "Federal": "Federal", "All States": "All States", "Other": "Other States"
132
  }
133
+ logger.info("✅ States dictionary loaded.")
134
 
135
+ # Verdict AI API call function (Grok-backed; supports streaming)
136
+ def ask_grok(messages, stream=False):
137
  try:
138
+ headers = {
139
+ "Accept": "application/json",
140
+ "Content-Type": "application/json",
141
+ "Authorization": f"Bearer {GROK_API_TOKEN}"
142
+ }
143
  payload = {
144
  "messages": messages,
145
+ "model": "grok-4-0709",
146
+ "stream": stream,
147
+ "temperature": 0.1,
148
+ "max_tokens": 131072, # High value for long responses
149
+ "search_parameters": {
150
+ "mode": "on"
151
  }
152
  }
153
+ logger.info(f"Grok payload: {payload}")
154
+ response = session.post(GROK_API_URL, headers=headers, json=payload, stream=stream)  # use the retry-enabled session defined above
155
+ logger.info(f"Grok response status: {response.status_code}")
156
  response.raise_for_status()
157
+ if stream:
158
+ def stream_gen():
159
+ logger.info("Starting Grok stream...")
160
+ for raw_chunk in response.iter_lines():
161
+ chunk = raw_chunk.decode("utf-8").strip()
162
+ if not chunk:
163
+ continue # Skip empty lines
164
+ chunk_data = chunk.replace("data: ", "")
165
+ logger.info(f"Raw chunk: {chunk_data}")
166
+ if chunk_data == "[DONE]":
167
+ yield "data: [DONE]\n\n"
168
+ break
169
+ try:
170
+ result = json.loads(chunk_data)
171
+ delta = result.get("choices", [{}])[0].get("delta", {})
172
+ content = delta.get("content", "")
173
+ if content:
174
+ yield f'data: {{"chunk": {json.dumps(content)}}}\n\n'
175
+ except Exception as e:
176
+ logger.warning(f"Grok JSON parse error: {e} | chunk_data: {chunk_data}")
177
+ yield f'data: {{"chunk": "[Unrecognized Grok output]"}}\n\n'
178
+ logger.info("Stream ended.")
179
+ return stream_gen()
180
  else:
181
+ result = response.json()
182
+ logger.info(f"Grok non-stream result: {result}")
183
+ if "choices" in result and result["choices"] and "message" in result["choices"][0] and "content" in result["choices"][0]["message"]:
184
+ content = result["choices"][0]["message"]["content"]
185
+ if len(content) > 65536:
186
+ content = content[:65536] + "... [Truncated]"
187
+ return content.strip()
188
+ return "[No response]"
189
+ except requests.exceptions.HTTPError as http_err:
190
+ logger.error(f"Grok HTTP error: {http_err}, Response: {response.text if 'response' in locals() else 'N/A'}")
191
+ if stream:
192
+ def error_gen():
193
+ yield f'data: {{"error": {json.dumps("Grok API error: " + str(http_err))}}}\n\n'
194
+ yield "data: [DONE]\n\n"
195
+ return error_gen()
196
+ return "[Grok Error] " + str(http_err)
197
  except Exception as e:
198
+ logger.error(f"Grok general error: {type(e).__name__}: {str(e)}")
199
+ if stream:
200
+ def error_gen():
201
+ yield f'data: {{"error": {json.dumps(str(e))}}}\n\n'
202
+ yield "data: [DONE]\n\n"
203
+ return error_gen()
204
+ return "[No response]"
205
+
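A minimal non-streaming usage sketch (not part of this commit), mirroring how summarize_document and the other document helpers below call ask_grok:

# Hypothetical usage (sketch only): one-off, non-streaming Grok call
reply = ask_grok([{"role": "user", "content": "In two sentences, explain comparative fault under KRS 411.182."}], stream=False)
print(reply)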
206
+ def extract_text_from_file(file_path):
207
  try:
208
+ ext = os.path.splitext(file_path)[1].lower()
209
+ text = ""
210
+ if ext == '.pdf':
211
+ with pdfplumber.open(file_path) as pdf:
212
+ text = "\n".join([page.extract_text() or "" for page in pdf.pages])
213
+ elif ext == '.docx':
214
+ doc = Document(file_path)
215
+ text = "\n".join([para.text for para in doc.paragraphs])
216
+ elif ext == '.txt':
217
+ with open(file_path, 'r', encoding='utf-8') as f:
218
+ text = f.read()
219
+ else:
220
+ text = f"Non-text file uploaded: {os.path.basename(file_path)}. No text could be extracted; treat it as an image or other binary attachment."
221
+ logger.info(f"Extracted text length: {len(text)} from {ext} file")
222
+ return text
223
  except Exception as e:
224
+ logger.error(f"File extraction error: {str(e)}")
225
  return ""
226
 
227
  def classify_prompt(prompt):
228
  prompt_lower = prompt.lower()
229
  if "summarize" in prompt_lower:
230
  return "document_analysis" # Treat summarize as analysis for routing
231
+ if any(k in prompt_lower for k in ["irac", "issue", "rule", "analysis", "conclusion", "brief", "memorandum", "memo"]):
232
  return "irac"
233
  elif any(k in prompt_lower for k in ["case", "precedent", "law"]):
234
  return "case_law"
235
  elif any(k in prompt_lower for k in ["statute", "krs"]):
236
  return "statute"
237
+ elif any(k in prompt_lower for k in ["draft", "write", "generate", "petition", "letter", "contract", "title opinion"]):
238
  return "document_creation"
239
  elif any(k in prompt_lower for k in ["review", "summarize", "clause", "red flags"]):
240
  return "document_analysis"
 
260
  return "legal_strategy"
261
  return "general_qa"
262
 
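For reference, the keyword routing above yields classifications like those in this illustrative sketch (not part of the commit); each result follows directly from the checks in classify_prompt:

# Illustrative expected routing, based on the keyword checks above
assert classify_prompt("Summarize this operating agreement") == "document_analysis"
assert classify_prompt("IRAC analysis: a customer slips on an icy sidewalk") == "irac"
assert classify_prompt("Draft a petition for custody") == "document_creation"
assert classify_prompt("What does KRS 411.182 mean for comparative fault?") == "statute"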
263
+ def create_legal_docx(content, jurisdiction, filename):
264
+ doc = Document()
265
+ # Set margins and font
266
+ sections = doc.sections
267
+ for section in sections:
268
+ section.top_margin = Inches(1)
269
+ section.bottom_margin = Inches(1)
270
+ section.left_margin = Inches(1)
271
+ section.right_margin = Inches(1)
272
+ # Case Caption (example placeholder)
273
+ caption = doc.add_paragraph()
274
+ caption.alignment = WD_ALIGN_PARAGRAPH.CENTER
275
+ run = caption.add_run("IN THE [COURT NAME] OF [JURISDICTION]\n")
276
+ run.bold = True
277
+ run.font.size = Pt(12)
278
+ caption.add_run("[Plaintiff] v. [Defendant]\nCase No: [Number]")
279
+ # Add content (assume content has sections marked with # for headings)
280
+ lines = content.split('\n')
281
+ for line in lines:
282
+ if line.startswith('# '):
283
+ heading = doc.add_heading(line[2:], level=1)
284
+ heading.alignment = WD_ALIGN_PARAGRAPH.CENTER
285
+ elif line.startswith('## '):
286
+ doc.add_heading(line[3:], level=2)
287
  else:
288
+ p = doc.add_paragraph(line)
289
+ p.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY
290
+ # Signature Block
291
+ doc.add_paragraph("\nRespectfully submitted,")
292
+ sig = doc.add_paragraph("[Attorney Name]\n[Bar Number]\n[Firm]\n[Address]\n[Phone]\n[Email]")
293
+ sig.alignment = WD_ALIGN_PARAGRAPH.LEFT
294
+ # Certificate of Service
295
+ doc.add_heading("CERTIFICATE OF SERVICE", level=1)
296
+ doc.add_paragraph("I hereby certify that a true and correct copy of the foregoing was served on [date] via [method] to:\n[Recipient]")
297
+ # Notary Acknowledgement (if applicable)
298
+ doc.add_heading("NOTARY ACKNOWLEDGEMENT", level=1)
299
+ doc.add_paragraph("[State/County]\nSubscribed and sworn to before me this [date] by [name].\n\nNotary Public")
300
+ doc.save(filename)
301
+ return filename
302
+
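The function above expects the model output to mark sections with '#'/'##' prefixes, which become Word headings; a minimal hypothetical usage (the content string and file name are placeholders):

# Hypothetical usage sketch of create_legal_docx (not part of the commit)
sample = "# MOTION TO DISMISS\n## INTRODUCTION\nDefendant moves to dismiss because ..."
create_legal_docx(sample, jurisdiction="KY", filename="/tmp/example_motion.docx")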
303
+ def route_model(messages, task_type, files=None, search_web=False, jurisdiction="KY"):
304
+ logger.info(f"Routing messages, Task: {task_type}, Web Search: {search_web}, Jurisdiction: {jurisdiction}")
305
+ rag_context = ""
306
+ prompt = messages[-1]['content'] # Use last user message as prompt for classification etc.
307
+ if task_type in ["case_law", "irac", "statute"]: # Skip RAG for document_creation/summaries
308
+ cap_dataset = get_cap_dataset()
309
+ if cap_dataset is not None:
310
+ combined_results = retrieve_context(prompt, task_type, jurisdiction)
311
+ # Filter by jurisdiction if specified
312
+ if jurisdiction and jurisdiction != "All States":
313
+ state_name = STATES.get(jurisdiction, "").lower()
314
+ state_code = jurisdiction.lower()
315
+ variants = [state_code, state_name, f"{state_code}.", state_name.replace(" ", "")]
316
+ combined_results = [r for r in combined_results if any(v in (r.get('citation', '') + r.get('name', '') + r.get('snippet', '')).lower() for v in variants)]
317
+ if combined_results:
318
+ rag_context = "Retrieved legal authorities (case law and statutes):\n" + "\n".join(
319
+ [f"{i+1}. [{auth.get('source', 'Unknown')}] {auth['name']}, {auth['citation']}: \"{auth['snippet']}\"" for i, auth in enumerate(combined_results)]
320
+ )
321
+ messages[-1]['content'] = f"{prompt}\n\n{rag_context}"
322
+ if task_type == "document_creation":
323
+ # Reset messages to only current prompt to avoid history accumulation
324
+ prompt = messages[-1]['content']
325
+ draft_messages = [{'role': 'user', 'content': prompt}]
326
+ # Route directly to fine-tuned GPT for document creation
327
+ gpt_response = ask_gpt41_mini(prompt, jurisdiction)  # gpt_helpers currently takes a single prompt string; switch to full messages if it adds support
328
+ logger.info(f"GPT-4.1-mini response length: {len(gpt_response)} | Content snippet: {gpt_response[:200]}...")
329
+ if not gpt_response.strip():
330
+ logger.warning("Empty response from GPT-4.1-mini; possible content filtering.")
331
+ yield f'data: {{"error": "Empty draft from GPT-4.1-mini - prompt may be filtered. Try rephrasing."}}\n\n'
332
+ yield "data: [DONE]\n\n"
333
+ return
334
+ # Truncate if too long to prevent token issues
335
+ MAX_GPT_LEN = 20000
336
+ if len(gpt_response) > MAX_GPT_LEN:
337
+ gpt_response = gpt_response[:MAX_GPT_LEN] + "\n[Truncated: GPT response too long; refining may be needed.]"
338
+ logger.warning(f"Truncated GPT response to {MAX_GPT_LEN} chars.")
339
+ editor_messages = draft_messages + [{'role': 'assistant', 'content': gpt_response}]
340
+ editor_prompt = build_editor_prompt(prompt, task_type, jurisdiction, gpt_response, rag_context)  # build the editor instruction from the draft and retrieved context
341
+ editor_messages.append({'role': 'user', 'content': editor_prompt})  # append the editor instruction as the latest user turn
342
+ # Use non-stream for Grok to avoid streaming issues
343
+ try:
344
+ full_grok_response = ask_grok(editor_messages, stream=False)  # non-streaming call for reliability
345
+ logger.info(f"Grok polish response length: {len(full_grok_response)} | Snippet: {full_grok_response[:200]}...")
346
+ if not full_grok_response.strip():
347
+ logger.warning("Empty response from Grok; using GPT draft.")
348
+ full_response = gpt_response
349
+ else:
350
+ full_response = full_grok_response
351
+ except Exception as e:
352
+ logger.error(f"Grok non-stream error: {str(e)}. Using GPT draft.")
353
+ full_response = gpt_response
354
+ # Yield as faux stream chunks
355
+ chunks = [full_response[i:i+200] for i in range(0, len(full_response), 200)] # Split for streaming feel
356
+ for part in chunks:
357
+ yield f'data: {{"chunk": {json.dumps(part)}}}\n\n' # Use json.dumps for safe escaping
358
+ # Create doc and send download URL
359
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
360
+ filename = f"/tmp/legal_doc_{timestamp}.docx"
361
+ create_legal_docx(full_response, jurisdiction, filename)
362
+ yield f'data: {{"download_url": "/download/legal_doc_{timestamp}.docx"}}\n\n'
363
+ yield "data: [DONE]\n\n"
364
+ return
365
  else:
366
+ try:
367
+ # Build a contextual system prompt
368
+ system_content = build_grok_prompt(prompt, task_type, jurisdiction, rag_context)  # prepended as a system message below if one is not already present
369
+ system_content += "\nStick strictly to the provided retrieved context for your response. Do not add information, cases, or statutes not explicitly in the context to avoid hallucinations. If context is insufficient, state so clearly."
370
+ if 'CourtListener' in rag_context:
371
+ system_content += "\nPrioritize CourtListener results for accuracy: Quote key snippets, cite cases, and polish into a structured response (e.g., IRAC format for analysis tasks)."
372
+ if messages[0]['role'] != 'system':
373
+ messages = [{'role': 'system', 'content': system_content}] + messages
374
+ stream_grok = ask_grok(messages, stream=True)
375
+ except Exception as e:
376
+ logger.error(f"Grok failed: {e}. Falling back to GPT-4o.")
377
+ grok_response = ask_gpt4o(messages[-1]['content'])  # fallback uses only the latest user message
378
+ yield f'data: {{"chunk": {json.dumps(grok_response)}}}\n\n'
379
+ yield "data: [DONE]\n\n"
380
+ return
381
+ # Task-specific processing
382
+ # (skipped for streaming responses; raw chunks are forwarded as-is)
383
+ for chunk in stream_grok:
384
+ yield chunk
385
+ yield "data: [DONE]\n\n"
386
 
387
+ def ask_gpt4o(prompt):
388
+ try:
389
+ irac_system = "If the task involves legal analysis, polish and organize the output into clear IRAC format. Otherwise, organize appropriately without IRAC."
390
+ response = openai_client.chat.completions.create(
391
+ model="gpt-4o",
392
+ messages=[
393
+ {
394
+ "role": "system",
395
+ "content": (
396
+ f"You are the final editor for a legal research assistant. {irac_system} "
397
+ "Ensure high quote density from retrieved authorities and include relevant facts from the cited cases. "
398
+ "Maintain accurate citations. Do not paraphrase legal holdings when direct quotes are available. "
399
+ "Do not cite or reference any case law, statutes, or authorities that are not explicitly provided in the retrieved context or user input."
400
+ )
401
+ },
402
+ {"role": "user", "content": prompt}
403
+ ],
404
+ temperature=0.3,
405
+ max_tokens=65536
406
+ )
407
+ return response.choices[0].message.content
408
+ except Exception as e:
409
+ logger.error(f"GPT-4o error: {str(e)}")
410
+ return f"[GPT-4o Error] {str(e)}"
411
 
412
  def summarize_document(files):
413
+ def gen():
414
+ if files and isinstance(files, list) and files:
415
+ texts = [extract_text_from_file(f) for f in files]
416
+ text = "\n".join(texts)
417
+ if text:
418
+ summary = ask_grok([{"role": "user", "content": f"Summarize the following document(s): {text[:10000]}"}], stream=False) # Explicitly non-stream
419
+ full_response = f"Summary: {summary}"
420
+ chunks = [full_response[i:i+200] for i in range(0, len(full_response), 200)] # Split for streaming feel
421
+ for part in chunks:
422
+ yield f'data: {{"chunk": {json.dumps(part)}}}\n\n'
423
+ yield "data: [DONE]\n\n"
424
+ else:
425
+ yield f'data: {{"chunk": "No text extracted from file."}}\n\n'
426
+ yield "data: [DONE]\n\n"
427
+ else:
428
+ yield f'data: {{"chunk": "Please upload a file to summarize."}}\n\n'
429
+ yield "data: [DONE]\n\n"
430
+ return gen
431
 
432
  def analyze_document(files):
433
+ def gen():
434
+ if files:
435
+ texts = [extract_text_from_file(f) for f in files]
436
+ text = "\n".join(texts)
437
+ if text:
438
+ analysis = ask_grok([{"role": "user", "content": f"Analyze the following document(s) for legal issues, risks, or key clauses: {text[:10000]}"}], stream=False) # Explicitly non-stream
439
+ full_response = f"Analysis: {analysis}"
440
+ chunks = [full_response[i:i+200] for i in range(0, len(full_response), 200)]
441
+ for part in chunks:
442
+ yield f'data: {{"chunk": {json.dumps(part)}}}\n\n'
443
+ yield "data: [DONE]\n\n"
444
+ else:
445
+ yield f'data: {{"chunk": "No text extracted from file."}}\n\n'
446
+ yield "data: [DONE]\n\n"
447
+ else:
448
+ yield f'data: {{"chunk": "No file uploaded for analysis."}}\n\n'
449
+ yield "data: [DONE]\n\n"
450
+ return gen
451
 
452
  def check_issues(files):
453
+ def gen():
454
+ if files:
455
+ texts = [extract_text_from_file(f) for f in files]
456
+ text = "\n".join(texts)
457
+ if text:
458
+ issues = ask_grok([{"role": "user", "content": f"Check for red flags, unusual clauses, or potential issues in this legal document(s) and highlight them: {text[:10000]}"}], stream=False) # Explicitly non-stream
459
+ full_response = f"Highlighted Issues: {issues}"
460
+ chunks = [full_response[i:i+200] for i in range(0, len(full_response), 200)]
461
+ for part in chunks:
462
+ yield f'data: {{"chunk": {json.dumps(part)}}}\n\n'
463
+ yield "data: [DONE]\n\n"
464
+ else:
465
+ yield f'data: {{"chunk": "No text extracted from file."}}\n\n'
466
+ yield "data: [DONE]\n\n"
467
+ else:
468
+ yield f'data: {{"chunk": "No file uploaded to check."}}\n\n'
469
+ yield "data: [DONE]\n\n"
470
+ return gen
471
+
472
+ # Error handlers to always return JSON
473
+ @app.errorhandler(400)
474
+ def bad_request(error):
475
+ return jsonify({'error': 'Bad request'}), 400
476
+
477
+ @app.errorhandler(404)
478
+ def not_found(error):
479
+ return jsonify({'error': 'Not found'}), 404
480
+
481
+ @app.errorhandler(405)
482
+ def method_not_allowed(error):
483
+ return jsonify({'error': 'Method not allowed'}), 405
484
+
485
+ @app.errorhandler(500)
486
+ def internal_error(error):
487
+ return jsonify({'error': 'Internal server error'}), 500
488
+
489
+ @app.errorhandler(Exception)
490
+ def handle_exception(e):
491
+ logger.error(f"Unhandled exception: {str(e)}")
492
+ return jsonify({'error': str(e)}), 500
493
+
494
+ # Flask routes
495
+ @app.route('/')
496
+ def index():
497
+ return send_from_directory('.', 'index.html')
498
+
499
+ @app.route('/api/chat', methods=['POST'])
500
+ def api_chat():
501
+ temp_paths = [] # Initialize here for finally block
502
+ def generate():
503
+ try:
504
+ # Early check for missing data
505
+ if 'payload' not in request.form:
506
+ yield f'data: {{"error": "Missing payload in request"}}\n\n'
507
+ yield "data: [DONE]\n\n"
508
+ return
509
+ payload = json.loads(request.form['payload'])
510
+ messages = payload['messages']
511
+ jurisdiction = payload['jurisdiction']
512
+ irac_mode = payload['irac_mode']
513
+ search_web = payload['web_search']
514
+ uploaded_files = request.files.getlist('file')
515
+ file_texts = []
516
+ if uploaded_files:
517
+ for file in uploaded_files:
518
+ if file.filename:
519
+ filename = secure_filename(file.filename)
520
+ temp_path = os.path.join('/tmp', filename)
521
+ file.save(temp_path)
522
+ file_text = extract_text_from_file(temp_path)
523
+ file_texts.append(file_text)
524
+ temp_paths.append(temp_path)
525
+ file_text_combined = "\n".join(file_texts)
526
+ prompt = messages[-1]['content'] # for classification
527
+ task_type = classify_prompt(prompt)
528
+ if irac_mode:
529
+ task_type = "irac"
530
+ # Append file text to last user message if present
531
+ if file_text_combined:
532
+ messages[-1]['content'] += "\nAttached file content(s): " + file_text_combined[:10000]
533
+ if "summarize" in prompt.lower():
534
+ task_type = "document_analysis"
535
+ gen_func = summarize_document(temp_paths)
536
+ for chunk in gen_func():
537
+ yield chunk
538
+ elif "analyze" in prompt.lower():
539
+ task_type = "document_analysis"
540
+ gen_func = analyze_document(temp_paths)
541
+ for chunk in gen_func():
542
+ yield chunk
543
+ elif "check" in prompt.lower() or "issues" in prompt.lower() or "highlight" in prompt.lower():
544
+ task_type = "document_analysis"
545
+ gen_func = check_issues(temp_paths)
546
+ for chunk in gen_func():
547
+ yield chunk
548
+ elif "generate" in prompt.lower() or "draft" in prompt.lower():
549
+ task_type = "document_creation"
550
+ for line in route_model(messages, task_type, temp_paths, search_web, jurisdiction):
551
+ yield line
552
+ else:
553
+ for line in route_model(messages, task_type, temp_paths, search_web, jurisdiction):
554
+ yield line
555
+ logger.info("Grok response streamed.")
556
+ except Exception as e:
557
+ logger.error(f"Error in /api/chat: {str(e)}")
558
+ yield f'data: {{"error": {json.dumps(str(e))}}}\n\n'
559
+ yield "data: [DONE]\n\n"
560
+ finally:
561
+ # Clean up temporary upload files
562
+ for temp_path in temp_paths:
563
+ try:
564
+ os.remove(temp_path)
565
+ except Exception as cleanup_e:
566
+ logger.error(f"Cleanup error: {str(cleanup_e)}")
567
+ return Response(stream_with_context(generate()), mimetype='text/event-stream')
568
+
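For reference, a minimal client-side sketch (not part of this commit) showing how the SSE-style stream emitted by /api/chat can be consumed; the base URL is an assumption and no files are attached:

# Hypothetical client sketch; assumes the Flask app above is running at BASE_URL.
import json, requests

BASE_URL = "http://localhost:7860"  # assumption: local dev address

def stream_chat(messages, jurisdiction="KY"):
    payload = {"messages": messages, "jurisdiction": jurisdiction,
               "irac_mode": False, "web_search": False}
    with requests.post(f"{BASE_URL}/api/chat",
                       data={"payload": json.dumps(payload)}, stream=True) as resp:
        for raw in resp.iter_lines():
            if not raw:
                continue
            data = raw.decode("utf-8").removeprefix("data: ").strip()
            if data == "[DONE]":
                break
            event = json.loads(data)
            if "chunk" in event:
                print(event["chunk"], end="", flush=True)
            elif "download_url" in event:
                print(f"\n[Download] {BASE_URL}{event['download_url']}")
            elif "error" in event:
                print(f"\n[Error] {event['error']}")

stream_chat([{"role": "user", "content": "Find Kentucky cases on constructive trust."}])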
569
+ @app.route('/download/<filename>', methods=['GET'])
570
+ def download(filename):
571
+ return send_file(os.path.join('/tmp', secure_filename(filename)), as_attachment=True)  # sanitize the name so downloads stay inside /tmp
572
+
573
+ @app.route('/health', methods=['GET'])
574
+ def health():
575
+ return "OK", 200
576
+
577
+ if __name__ == '__main__':
578
+ logger.info("✅ All init complete. Starting Flask app...")
579
+ print("Flask app starting...") # Fallback print
580
+ app.run(host='0.0.0.0', port=7860)