Spaces: Runtime error
Update App.py
App.py (CHANGED)
@@ -1,31 +1,18 @@
-#
-# This is the updated main script. Copy-paste this over your existing script.py.
-# Changes:
-# - Fixed any potential issues with prompt_lower by ensuring consistent use of prompt.lower() (though it was already correct).
-# - Split route_model responsibilities: Moved retrieval functions to retrieval.py, prompt building to prompt_builder.py, post-processing to post_processing.py.
-# - Imports: Add 'from retrieval import *', 'from prompt_builder import *', 'from post_processing import *' after your existing imports.
-# - For synchronous loads: hf_hub_download already caches, so subsequent calls are fast. To avoid first-time blocking, I've added a background preloader thread that pre-downloads a limited number of clusters (e.g., first 10) at startup. If you have many clusters, create a separate script (see below) to pre-download all offline.
-# - No need for a separate script unless you want to pre-download ALL clusters (which could be storage-intensive). If yes, see the optional pre_download_clusters.py below.
-# - In route_model, now calls the split functions.
-# - Passed necessary globals (e.g., cap_dataset, cap_id_to_index) to semantic_search.
-# - Ensured municipal_embeddings is loaded (assume it's global).
-# - Added threading for preloading some clusters at startup.
-
-import gradio as gr
-from openai import OpenAI
 import requests
 import os
 import logging
 from datetime import datetime
 import pdfplumber
-from
 import re
-from datasets import load_dataset,
 from sentence_transformers import SentenceTransformer
 import torch
 import numpy as np
 import shutil
-import pyarrow.parquet as pq
 from huggingface_hub import hf_hub_download
 import pickle
 import faiss
@@ -33,340 +20,221 @@ import threading
 import subprocess
 from task_processing import process_task_response
 from gpt_helpers import ask_gpt41_mini
 
-#
-from retrieval import *
-from prompt_builder import *
-from post_processing import *
 
 os.environ["HF_HOME"] = "/data/.huggingface"
-# Add or update this section in script.py
-# Ensure this is placed after imports but before any dataset loading or function definitions
 
-
 
-#
 hf_token = os.environ.get("HF_TOKEN", "")
-if not hf_token:
-    logger.warning("HF_TOKEN not set; SaulLM endpoint may require authentication and gated repos may not be accessible.")
-
-# Authenticate for gated Hugging Face repos (e.g., for centroids download)
 if hf_token:
     login(hf_token)
-    logger.info("Authenticated with Hugging Face token for gated repos.")
 else:
-    logger.warning("
 
 # Check environment variables
-[… old lines 61-67 removed; content not preserved in this view …]
-    logger.
-
-    logger.
-[… old lines 71-78 removed; content not preserved in this view …]
-        logger.info(f"Downloaded and saved file to {save_path}")
-    except Exception as e:
-        logger.error(f"Failed to download from {url}: {str(e)}")
-
-# Download the centroids file if not present
-centroid_url = "https://huggingface.co/datasets/laion/Caselaw_Access_Project_embeddings/blob/main/TeraflopAI___Caselaw_Access_Project_centroids.parquet"
-centroid_path = "TeraflopAI___Caselaw_Access_Project_centroids.parquet"
-download_file_if_not_exists(centroid_url, centroid_path)
-# Load HF token for SaulLM endpoint
-hf_token = os.environ.get("HF_TOKEN", "")
-if not hf_token:
-    logger.warning("HF_TOKEN not set; SaulLM endpoint may require authentication")
-
-import requests
-
-
-# Initialize OpenAI client
-openai_client = OpenAI(api_key=OPENAI_API_KEY)
-
-# SaulLM endpoint
-SAUL_ENDPOINT = "https://l4tuv4j9bu616t5x.us-east-1.aws.endpoints.huggingface.cloud"
-
-# Persistent storage path for dataset
-LOCAL_PATH = "/data/cap_dataset"
-dataset_info_path = os.path.join(LOCAL_PATH, 'dataset_info.json')
-if os.path.exists(dataset_info_path):
-    cap_dataset = load_from_disk(LOCAL_PATH)
 else:
-[… old lines 107-131 removed; content not preserved in this view …]
-    logger.
-[… old lines 133-145 removed; content not preserved in this view …]
-    logger.
-[… old lines 147-150 removed; content not preserved in this view …]
-run_build_municipal_faiss_once()
-
-# Load municipal embeddings dataset
-if os.path.exists(MUNICIPAL_EMBEDDINGS_PATH):
-    municipal_embeddings = load_from_disk(MUNICIPAL_EMBEDDINGS_PATH)
-else:
-    logger.error("Municipal embeddings not found. Ensure prepare_municipal_embeddings.py ran successfully.")
-    municipal_embeddings = None  # Fallback or error handling
-
-# Load municipal html dataset
-if os.path.exists(MUNICIPAL_HTML_PATH):
-    municipal_html = load_from_disk(MUNICIPAL_HTML_PATH)
-else:
-    logger.error("Municipal html not found. Ensure prepare_municipal_embeddings.py ran successfully.")
-    municipal_html = None
-
-# Load municipal citation dataset
-if os.path.exists(MUNICIPAL_CITATION_PATH):
-    municipal_citation = load_from_disk(MUNICIPAL_CITATION_PATH)
-else:
-    logger.error("Municipal citation not found. Ensure prepare_municipal_embeddings.py ran successfully.")
-    municipal_citation = None
-
-# Precompute CID to index mapping for CAP dataset
-cap_id_to_index = {doc['cid']: i for i, doc in enumerate(cap_dataset) if 'cid' in doc}
-
-# Preload some clusters in background (e.g., clusters 0-9)
-def preload_clusters():
-    for cluster_id in range(10):  # Adjust range as needed
-        try:
-            load_cluster_vectors(cluster_id, model="gte-large")
-            logger.info(f"Preloaded cluster {cluster_id}")
-        except Exception as e:
-            logger.error(f"Preload failed for cluster {cluster_id}: {e}")
 
-
 
 # State dictionary for jurisdiction
 STATES = {
-    "AL": "Alabama",
-[… old lines 191-200 ("AK" through "HI") truncated in this view …]
-    "ID": "Idaho",
-    "IL": "Illinois",
-    "IN": "Indiana",
-    "IA": "Iowa",
-    "KS": "Kansas",
-    "KY": "Kentucky",
-    "LA": "Louisiana",
-    "ME": "Maine",
-    "MD": "Maryland",
-    "MA": "Massachusetts",
-    "MI": "Michigan",
-    "MN": "Minnesota",
-    "MS": "Mississippi",
-    "MO": "Missouri",
-    "MT": "Montana",
-    "NE": "Nebraska",
-    "NV": "Nevada",
-    "NH": "New Hampshire",
-    "NJ": "New Jersey",
-    "NM": "New Mexico",
-    "NY": "New York",
-    "NC": "North Carolina",
-    "ND": "North Dakota",
-    "OH": "Ohio",
-    "OK": "Oklahoma",
-    "OR": "Oregon",
-    "PA": "Pennsylvania",
-    "RI": "Rhode Island",
-    "SC": "South Carolina",
-    "SD": "South Dakota",
-    "TN": "Tennessee",
-    "TX": "Texas",
-    "UT": "Utah",
-    "VT": "Vermont",
-    "VA": "Virginia",
-    "WA": "Washington",
-    "WV": "West Virginia",
-    "WI": "Wisconsin",
-    "WY": "Wyoming",
-    "Federal": "Federal",
-    "All States": "All States",
-    "Other": "Other States"
 }
 
-
-
-
-    rag_context = ""
-    if task_type in ["case_law", "irac", "statute"]:
-        cap_results = semantic_search(prompt, top_k=5)
-        municipal_results = municipal_search(prompt, top_k=5)
-        combined_results = cap_results + municipal_results
-
-        # Filter by jurisdiction if specified (e.g., "KY" for Kentucky)
-        if jurisdiction and jurisdiction != "All States":
-            state_name = STATES.get(jurisdiction, "")
-            state_code = jurisdiction  # e.g., "KY"
-            combined_results = [r for r in combined_results if state_code in r['citation'] or state_name in r['citation'] or state_code in r['name'] or state_name in r['name']]
-
-        if combined_results:
-            rag_context = "Retrieved legal authorities (case law and statutes):\n" + "\n".join([f"{i+1}. [{auth.get('source', 'Unknown')}] {auth['name']}, {auth['citation']}: \"{auth['snippet']}\"" for i, auth in enumerate(combined_results)])
-
-    prompt = f"User prompt: {prompt}\n\n{rag_context}"
-
-    saul_response = ask_saul(prompt, task_type, jurisdiction)
-
-    # Task-specific processing (existing code)
-    saul_response = process_task_response(task_type, saul_response, prompt, jurisdiction)
-
-    if search_web:
-        web_data = google_search(prompt)
-        saul_response = f"Google Search results: {web_data}\n{saul_response}"
-
-    editor_prompt = build_editor_prompt(prompt, task_type, jurisdiction, saul_response, rag_context)
-
-    final_response = ask_gpt4o(editor_prompt)
-
-    final_response = ground_statutes(final_response, jurisdiction)
-
-    return final_response
-
-def ask_saul(messages, task_type, jurisdiction):
     try:
-        headers = {
         payload = {
             "messages": messages,
-            "
-
-
             }
         }
-        logger.info(f"
-        response = requests.post(
         response.raise_for_status()
-[… old lines 295-299 removed; content not preserved in this view …]
         else:
-
-
     except Exception as e:
-        logger.error(f"
-[… old lines 305-307 removed; content not preserved in this view …]
     try:
-[… old lines 309-323 removed; content not preserved in this view …]
-        logger.error(f"GPT-4.1 Mini error: {str(e)}")
-        return f"[GPT-4.1 Mini Error] {str(e)}"
-
-def ask_gpt4o(prompt):
-    try:
-        response = openai_client.chat.completions.create(
-            model="gpt-4o",
-            messages=[
-                {"role": "system", "content": (
-                    "You are the final editor for a legal research assistant. Polish and organize the output into clear IRAC format. "
-                    "Ensure high quote density from retrieved authorities and include relevant facts from the cited cases. "
-                    "Maintain accurate citations. Do not paraphrase legal holdings when direct quotes are available."
-                )},
-                {"role": "user", "content": prompt}
-            ],
-            temperature=0.3,
-            max_tokens=16384
-        )
-        return response.choices[0].message.content
-    except Exception as e:
-        logger.error(f"GPT-4o error: {str(e)}")
-        return f"[GPT-4o Error] {str(e)}"
-
-def extract_text_from_pdf(file_path):
-    try:
-        with pdfplumber.open(file_path) as pdf:
-            text = ""
-            for page in pdf.pages:
-                text += page.extract_text() or ""
-        logger.info(f"Extracted text length: {len(text)}")
-        return text
     except Exception as e:
-        logger.error(f"
         return ""
 
 def classify_prompt(prompt):
     prompt_lower = prompt.lower()
     if "summarize" in prompt_lower:
         return "document_analysis"  # Treat summarize as analysis for routing
-    if any(k in prompt_lower for k in ["irac", "issue", "rule", "analysis", "conclusion"]):
         return "irac"
     elif any(k in prompt_lower for k in ["case", "precedent", "law"]):
         return "case_law"
     elif any(k in prompt_lower for k in ["statute", "krs"]):
         return "statute"
-    elif any(k in prompt_lower for k in ["draft", "write", "generate", "petition", "letter", "contract"]):
         return "document_creation"
     elif any(k in prompt_lower for k in ["review", "summarize", "clause", "red flags"]):
         return "document_analysis"
@@ -392,104 +260,321 @@ def classify_prompt(prompt):
         return "legal_strategy"
     return "general_qa"
 
-def
-[… old lines 396-411 removed; content not preserved in this view …]
 else:
-
-
 else:
-[… old lines 416-419 removed; content not preserved in this view …]
 
-def
-
 
 def summarize_document(files):
-[… old lines 425-432 removed; content not preserved in this view …]
 
 def analyze_document(files):
-[… old lines 435-441 removed; content not preserved in this view …]
 
 def check_issues(files):
-[… old lines 444-469 removed; content not preserved in this view …]
-    )
-[… old lines 471-477 removed; content not preserved in this view …]
-    )
-[… old lines 479-495 removed; content not preserved in this view …]
+# app.py
 import requests
 import os
 import logging
 from datetime import datetime
 import pdfplumber
+from docx import Document
+from docx.shared import Pt, Inches
+from docx.enum.text import WD_ALIGN_PARAGRAPH
 import re
+from datasets import load_dataset, load_from_disk
 from sentence_transformers import SentenceTransformer
 import torch
 import numpy as np
 import shutil
 from huggingface_hub import hf_hub_download
 import pickle
 import faiss
 import threading
 import subprocess
 from task_processing import process_task_response
 from gpt_helpers import ask_gpt41_mini
+from retrieval import retrieve_context
+from prompt_builder import build_grok_prompt, build_editor_prompt
+from flask import Flask, request, jsonify, send_from_directory, send_file, Response, stream_with_context
+from werkzeug.utils import secure_filename
+from rank_bm25 import BM25Okapi
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+import json  # For safer JSON parsing if needed
 
+app = Flask(__name__)  # Renamed from app_flask to app for HF Spaces compatibility
 
 os.environ["HF_HOME"] = "/data/.huggingface"
 
+# Logging setup
+logger = logging.getLogger("app")
+logging.basicConfig(level=logging.INFO)
+logger.info("✅ Logging initialized. Starting app setup.")
+print("App setup starting...")  # Fallback print for early debug
 
+# Hugging Face authentication
+from huggingface_hub import login
 hf_token = os.environ.get("HF_TOKEN", "")
 if hf_token:
     login(hf_token)
+    logger.info("✅ Authenticated with Hugging Face token for gated repos.")
 else:
+    logger.warning("HF_TOKEN not set; gated repos may not be accessible.")
 
 # Check environment variables
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "Missing")
+GOOGLE_SEARCH_API = os.environ.get("GOOGLE_SEARCH_API", "Missing")  # CSE ID
+GOOGLE_CUSTOM_SEARCH_API_KEY = os.environ.get("GOOGLE_CUSTOM_SEARCH_API_KEY", "Missing")  # API key
+COURT_LISTENER_API_KEY = os.environ.get("Court_Listener_API", "Missing")  # Updated to match HF secret name
+if OPENAI_API_KEY == "Missing":
+    logger.warning("OPENAI_API_KEY not set; OpenAI features will fail.")
+if GOOGLE_CUSTOM_SEARCH_API_KEY == "Missing" or GOOGLE_SEARCH_API == "Missing":
+    logger.warning("Google Search keys not set; search features will fail.")
+if COURT_LISTENER_API_KEY == "Missing":
+    logger.warning("Court_Listener_API not set; CourtListener features will fail.")
+logger.info("✅ API keys checked (with warnings if missing).")
+
+# Initialize OpenAI client (only if key present)
+openai_client = None
+if OPENAI_API_KEY != "Missing":
+    from openai import OpenAI
+    openai_client = OpenAI(api_key=OPENAI_API_KEY)
+    logger.info("✅ OpenAI client initialized.")
 else:
+    logger.warning("Skipping OpenAI client init due to missing key.")
+
+# Grok API setup
+GROK_API_URL = "https://api.x.ai/v1/chat/completions"
+GROK_API_TOKEN = os.environ.get("GROK_API_TOKEN", "")  # Read from env; never hardcode API keys in source
+logger.info("✅ Grok API endpoint and token set.")
+
+# Global session for retries
+session = requests.Session()
+retries = Retry(total=3, backoff_factor=1, status_forcelist=[422, 503, 504])
+session.mount('https://', HTTPAdapter(max_retries=retries))
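The Retry adapter above only applies to requests issued through `session`, and `ask_grok` below posts with the bare `requests` module, so these retries are currently bypassed. Also worth knowing: by default urllib3's `Retry` does not retry POST requests at all. A minimal sketch of a configuration that would actually retry these API calls (urllib3 1.26+; the payload is a placeholder, not the real Grok schema):

```python
# Sketch only: allowed_methods must include "POST" for status_forcelist
# to apply to POST requests; the default method set excludes it.
retries = Retry(total=3, backoff_factor=1,
                status_forcelist=[422, 503, 504],
                allowed_methods=["POST"])
session.mount('https://', HTTPAdapter(max_retries=retries))
resp = session.post(GROK_API_URL, json={"messages": []}, timeout=60)  # placeholder body
resp.raise_for_status()
```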
+
+# Lazy-load CAP dataset to avoid startup issues
+def get_cap_dataset():
+    if not hasattr(get_cap_dataset, 'dataset') or get_cap_dataset.dataset is None:
+        from datasets import load_from_disk  # Lazy import
+        LOCAL_PATH = "/data/cap_dataset"
+        if os.path.exists(os.path.join(LOCAL_PATH, 'dataset_info.json')):
+            try:
+                get_cap_dataset.dataset = load_from_disk(LOCAL_PATH)
+                logger.info("✅ Lazy-loaded CAP dataset from /data/cap_dataset.")
+            except Exception as e:
+                logger.error(f"Failed to load CAP dataset: {str(e)}")
+                get_cap_dataset.dataset = None
+        else:
+            logger.error("CAP dataset not found at /data/cap_dataset. Ensure it's preloaded.")
+            get_cap_dataset.dataset = None
+    return get_cap_dataset.dataset
+
+get_cap_dataset.dataset = None
+logger.info("✅ CAP dataset lazy-loader defined.")
+
+# Lazy-compute CID to index mapping for CAP dataset
+def get_cap_id_to_index():
+    if not hasattr(get_cap_id_to_index, 'index') or get_cap_id_to_index.index is None:
+        cap_dataset = get_cap_dataset()
+        if cap_dataset is not None:
+            get_cap_id_to_index.index = {doc['cid']: i for i, doc in enumerate(cap_dataset) if 'cid' in doc}
+            logger.info("✅ Precomputed CAP CID to index mapping.")
+        else:
+            get_cap_id_to_index.index = {}
+            logger.error("CAP dataset not available for index mapping.")
+    return get_cap_id_to_index.index
 
+get_cap_id_to_index.index = None
+logger.info("✅ CAP ID-to-index lazy-loader defined.")
 
 # State dictionary for jurisdiction
 STATES = {
+    "AL": "Alabama", "AK": "Alaska", "AZ": "Arizona", "AR": "Arkansas", "CA": "California",
+    "CO": "Colorado", "CT": "Connecticut", "DE": "Delaware", "FL": "Florida", "GA": "Georgia",
+    "HI": "Hawaii", "ID": "Idaho", "IL": "Illinois", "IN": "Indiana", "IA": "Iowa",
+    "KS": "Kansas", "KY": "Kentucky", "LA": "Louisiana", "ME": "Maine", "MD": "Maryland",
+    "MA": "Massachusetts", "MI": "Michigan", "MN": "Minnesota", "MS": "Mississippi", "MO": "Missouri",
+    "MT": "Montana", "NE": "Nebraska", "NV": "Nevada", "NH": "New Hampshire", "NJ": "New Jersey",
+    "NM": "New Mexico", "NY": "New York", "NC": "North Carolina", "ND": "North Dakota", "OH": "Ohio",
+    "OK": "Oklahoma", "OR": "Oregon", "PA": "Pennsylvania", "RI": "Rhode Island", "SC": "South Carolina",
+    "SD": "South Dakota", "TN": "Tennessee", "TX": "Texas", "UT": "Utah", "VT": "Vermont",
+    "VA": "Virginia", "WA": "Washington", "WV": "West Virginia", "WI": "Wisconsin", "WY": "Wyoming",
+    "Federal": "Federal", "All States": "All States", "Other": "Other States"
 }
+logger.info("✅ States dictionary loaded.")
 
+# Verdict AI API call function (updated for streaming)
+def ask_grok(messages, stream=False):
     try:
+        headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {GROK_API_TOKEN}"
+        }
         payload = {
             "messages": messages,
+            "model": "grok-4-0709",
+            "stream": stream,
+            "temperature": 0.1,
+            "max_tokens": 131072,  # High value for long responses
+            "search_parameters": {
+                "mode": "on"
             }
         }
+        logger.info(f"Grok payload: {payload}")
+        response = requests.post(GROK_API_URL, headers=headers, json=payload, stream=stream)
+        logger.info(f"Grok response status: {response.status_code}")
         response.raise_for_status()
+        if stream:
+            def stream_gen():
+                logger.info("Starting Grok stream...")
+                for raw_chunk in response.iter_lines():
+                    chunk = raw_chunk.decode("utf-8").strip()
+                    if not chunk:
+                        continue  # Skip empty lines
+                    chunk_data = chunk.replace("data: ", "")
+                    logger.info(f"Raw chunk: {chunk_data}")
+                    if chunk_data == "[DONE]":
+                        yield "data: [DONE]\n\n"
+                        break
+                    try:
+                        result = json.loads(chunk_data)
+                        delta = result.get("choices", [{}])[0].get("delta", {})
+                        content = delta.get("content", "")
+                        if content:
+                            yield f'data: {{"chunk": {json.dumps(content)}}}\n\n'
+                    except Exception as e:
+                        logger.warning(f"Grok JSON parse error: {e} | chunk_data: {chunk_data}")
+                        yield f'data: {{"chunk": "[Unrecognized Grok output]"}}\n\n'
+                logger.info("Stream ended.")
+            return stream_gen()
         else:
+            result = response.json()
+            logger.info(f"Grok non-stream result: {result}")
+            if "choices" in result and result["choices"] and "message" in result["choices"][0] and "content" in result["choices"][0]["message"]:
+                content = result["choices"][0]["message"]["content"]
+                if len(content) > 65536:
+                    content = content[:65536] + "... [Truncated]"
+                return content.strip()
+            return "[No response]"
+    except requests.exceptions.HTTPError as http_err:
+        logger.error(f"Grok HTTP error: {http_err}, Response: {response.text if 'response' in locals() else 'N/A'}")
+        if stream:
+            def error_gen():
+                yield f'data: {{"error": "Grok API error: {str(http_err)}"}}\n\n'
+                yield "data: [DONE]\n\n"
+            return error_gen()
+        return "[Grok Error] " + str(http_err)
     except Exception as e:
+        logger.error(f"Grok general error: {type(e).__name__}: {str(e)}")
+        if stream:
+            def error_gen():
+                yield f'data: {{"error": "{str(e)}"}}\n\n'
+                yield "data: [DONE]\n\n"
+            return error_gen()
+        return "[No response]"
+
+def extract_text_from_file(file_path):
     try:
+        ext = os.path.splitext(file_path)[1].lower()
+        text = ""
+        if ext == '.pdf':
+            with pdfplumber.open(file_path) as pdf:
+                text = "\n".join([page.extract_text() or "" for page in pdf.pages])
+        elif ext == '.docx':
+            doc = Document(file_path)
+            text = "\n".join([para.text for para in doc.paragraphs])
+        elif ext == '.txt':
+            with open(file_path, 'r', encoding='utf-8') as f:
+                text = f.read()
+        else:
+            text = f"Non-text file uploaded: {os.path.basename(file_path)}. Analyze if image or other."
+        logger.info(f"Extracted text length: {len(text)} from {ext} file")
+        return text
     except Exception as e:
+        logger.error(f"File extraction error: {str(e)}")
         return ""
 
 def classify_prompt(prompt):
     prompt_lower = prompt.lower()
     if "summarize" in prompt_lower:
         return "document_analysis"  # Treat summarize as analysis for routing
+    if any(k in prompt_lower for k in ["irac", "issue", "rule", "analysis", "conclusion", "brief", "memorandum", "memo"]):
         return "irac"
     elif any(k in prompt_lower for k in ["case", "precedent", "law"]):
         return "case_law"
     elif any(k in prompt_lower for k in ["statute", "krs"]):
         return "statute"
+    elif any(k in prompt_lower for k in ["draft", "write", "generate", "petition", "letter", "contract", "title opinion"]):
         return "document_creation"
     elif any(k in prompt_lower for k in ["review", "summarize", "clause", "red flags"]):
         return "document_analysis"
 
         return "legal_strategy"
     return "general_qa"
 
+def create_legal_docx(content, jurisdiction, filename):
+    doc = Document()
+    # Set margins and font
+    sections = doc.sections
+    for section in sections:
+        section.top_margin = Inches(1)
+        section.bottom_margin = Inches(1)
+        section.left_margin = Inches(1)
+        section.right_margin = Inches(1)
+    # Case Caption (example placeholder)
+    caption = doc.add_paragraph()
+    caption.alignment = WD_ALIGN_PARAGRAPH.CENTER
+    run = caption.add_run("IN THE [COURT NAME] OF [JURISDICTION]\n")
+    run.bold = True
+    run.font.size = Pt(12)
+    caption.add_run("[Plaintiff] v. [Defendant]\nCase No: [Number]")
+    # Add content (assume content has sections marked with # for headings)
+    lines = content.split('\n')
+    for line in lines:
+        if line.startswith('# '):
+            heading = doc.add_heading(line[2:], level=1)
+            heading.alignment = WD_ALIGN_PARAGRAPH.CENTER
+        elif line.startswith('## '):
+            doc.add_heading(line[3:], level=2)
         else:
+            p = doc.add_paragraph(line)
+            p.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY
+    # Signature Block
+    doc.add_paragraph("\nRespectfully submitted,")
+    sig = doc.add_paragraph("[Attorney Name]\n[Bar Number]\n[Firm]\n[Address]\n[Phone]\n[Email]")
+    sig.alignment = WD_ALIGN_PARAGRAPH.LEFT
+    # Certificate of Service
+    doc.add_heading("CERTIFICATE OF SERVICE", level=1)
+    doc.add_paragraph("I hereby certify that a true and correct copy of the foregoing was served on [date] via [method] to:\n[Recipient]")
+    # Notary Acknowledgement (if applicable)
+    doc.add_heading("NOTARY ACKNOWLEDGEMENT", level=1)
+    doc.add_paragraph("[State/County]\nSubscribed and sworn to before me this [date] by [name].\n\nNotary Public")
+    doc.save(filename)
+    return filename
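A minimal usage sketch; the draft text and output path are illustrative. Lines starting with "# " and "## " become level-1 and level-2 headings, everything else is justified body text:

```python
draft = "# MOTION TO DISMISS\n## Introduction\nComes now the Defendant, by counsel..."
path = create_legal_docx(draft, "KY", "/tmp/motion_to_dismiss.docx")
# The saved .docx carries the caption, body, signature block,
# certificate of service, and notary acknowledgement boilerplate.
```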
+
+def route_model(messages, task_type, files=None, search_web=False, jurisdiction="KY"):
+    logger.info(f"Routing messages, Task: {task_type}, Web Search: {search_web}, Jurisdiction: {jurisdiction}")
+    rag_context = ""
+    prompt = messages[-1]['content']  # Use last user message as prompt for classification etc.
+    if task_type in ["case_law", "irac", "statute"]:  # Skip RAG for document_creation/summaries
+        cap_dataset = get_cap_dataset()
+        if cap_dataset is not None:
+            combined_results = retrieve_context(prompt, task_type, jurisdiction)
+            # Filter by jurisdiction if specified
+            if jurisdiction and jurisdiction != "All States":
+                state_name = STATES.get(jurisdiction, "").lower()
+                state_code = jurisdiction.lower()
+                variants = [state_code, state_name, f"{state_code}.", state_name.replace(" ", "")]
+                combined_results = [r for r in combined_results if any(v in (r.get('citation', '') + r.get('name', '') + r.get('snippet', '')).lower() for v in variants)]
+            if combined_results:
+                rag_context = "Retrieved legal authorities (case law and statutes):\n" + "\n".join(
+                    [f"{i+1}. [{auth.get('source', 'Unknown')}] {auth['name']}, {auth['citation']}: \"{auth['snippet']}\"" for i, auth in enumerate(combined_results)]
+                )
+                messages[-1]['content'] = f"{prompt}\n\n{rag_context}"
+    if task_type == "document_creation":
+        # Reset messages to only the current prompt to avoid history accumulation
+        prompt = messages[-1]['content']
+        draft_messages = [{'role': 'user', 'content': prompt}]
+        # Route directly to fine-tuned GPT for document creation
+        gpt_response = ask_gpt41_mini(prompt, jurisdiction)  # Adjust to use full messages if gpt_helpers supports it
+        logger.info(f"GPT-4.1-mini response length: {len(gpt_response)} | Content snippet: {gpt_response[:200]}...")
+        if not gpt_response.strip():
+            logger.warning("Empty response from GPT-4.1-mini; possible content filtering.")
+            yield f'data: {{"error": "Empty draft from GPT-4.1-mini - prompt may be filtered. Try rephrasing."}}\n\n'
+            yield "data: [DONE]\n\n"
+            return
+        # Truncate if too long to prevent token issues
+        MAX_GPT_LEN = 20000
+        if len(gpt_response) > MAX_GPT_LEN:
+            gpt_response = gpt_response[:MAX_GPT_LEN] + "\n[Truncated: GPT response too long; refining may be needed.]"
+            logger.warning(f"Truncated GPT response to {MAX_GPT_LEN} chars.")
+        editor_messages = draft_messages + [{'role': 'assistant', 'content': gpt_response}]
+        editor_prompt = build_editor_prompt(prompt, task_type, jurisdiction, gpt_response, rag_context)
+        editor_messages.append({'role': 'user', 'content': editor_prompt})
+        # Use non-stream for Grok to avoid streaming issues
+        try:
+            full_grok_response = ask_grok(editor_messages, stream=False)  # CHANGED: non-stream for reliability
+            logger.info(f"Grok polish response length: {len(full_grok_response)} | Snippet: {full_grok_response[:200]}...")
+            if not full_grok_response.strip():
+                logger.warning("Empty response from Grok; using GPT draft.")
+                full_response = gpt_response
+            else:
+                full_response = full_grok_response
+        except Exception as e:
+            logger.error(f"Grok non-stream error: {str(e)}. Using GPT draft.")
+            full_response = gpt_response
+        # Yield as faux stream chunks
+        chunks = [full_response[i:i+200] for i in range(0, len(full_response), 200)]  # Split for streaming feel
+        for part in chunks:
+            yield f'data: {{"chunk": {json.dumps(part)}}}\n\n'  # Use json.dumps for safe escaping
+        # Create doc and send download URL
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        filename = f"/tmp/legal_doc_{timestamp}.docx"
+        create_legal_docx(full_response, jurisdiction, filename)
+        yield f'data: {{"download_url": "/download/legal_doc_{timestamp}.docx"}}\n\n'
+        yield "data: [DONE]\n\n"
+        return
     else:
+        try:
+            # Build the contextual system prompt (prepended below only if none present)
+            system_content = build_grok_prompt(prompt, task_type, jurisdiction, rag_context)
+            system_content += "\nStick strictly to the provided retrieved context for your response. Do not add information, cases, or statutes not explicitly in the context to avoid hallucinations. If context is insufficient, state so clearly."
+            if 'CourtListener' in rag_context:
+                system_content += "\nPrioritize CourtListener results for accuracy: Quote key snippets, cite cases, and polish into a structured response (e.g., IRAC format for analysis tasks)."
+            if messages[0]['role'] != 'system':
+                messages = [{'role': 'system', 'content': system_content}] + messages
+            stream_grok = ask_grok(messages, stream=True)
+        except Exception as e:
+            logger.error(f"Grok failed: {e}. Falling back to GPT-4o.")
+            grok_response = ask_gpt4o(messages[-1]['content'])  # Fallback; adjust to full messages if possible
+            yield f'data: {{"chunk": "{grok_response}"}}\n\n'
+            yield "data: [DONE]\n\n"
+            return
+        # Task-specific processing: for streaming, skip or adapt; here, stream raw
+        for chunk in stream_grok:
+            yield chunk
+        yield "data: [DONE]\n\n"
 
+def ask_gpt4o(prompt):
+    try:
+        irac_system = "If the task involves legal analysis, polish and organize the output into clear IRAC format. Otherwise, organize appropriately without IRAC."
+        response = openai_client.chat.completions.create(
+            model="gpt-4o",
+            messages=[
+                {
+                    "role": "system",
+                    "content": (
+                        f"You are the final editor for a legal research assistant. {irac_system} "
+                        "Ensure high quote density from retrieved authorities and include relevant facts from the cited cases. "
+                        "Maintain accurate citations. Do not paraphrase legal holdings when direct quotes are available. "
+                        "Do not cite or reference any case law, statutes, or authorities that are not explicitly provided in the retrieved context or user input."
+                    )
+                },
+                {"role": "user", "content": prompt}
+            ],
+            temperature=0.3,
+            max_tokens=65536
+        )
+        return response.choices[0].message.content
+    except Exception as e:
+        logger.error(f"GPT-4o error: {str(e)}")
+        return f"[GPT-4o Error] {str(e)}"
 
 def summarize_document(files):
+    def gen():
+        if files and isinstance(files, list) and files:
+            texts = [extract_text_from_file(f) for f in files]
+            text = "\n".join(texts)
+            if text:
+                summary = ask_grok([{"role": "user", "content": f"Summarize the following document(s): {text[:10000]}"}], stream=False)  # Explicitly non-stream
+                full_response = f"Summary: {summary}"
+                chunks = [full_response[i:i+200] for i in range(0, len(full_response), 200)]  # Split for streaming feel
+                for part in chunks:
+                    yield f'data: {{"chunk": {json.dumps(part)}}}\n\n'
+                yield "data: [DONE]\n\n"
+            else:
+                yield f'data: {{"chunk": "No text extracted from file."}}\n\n'
+                yield "data: [DONE]\n\n"
+        else:
+            yield f'data: {{"chunk": "Please upload a file to summarize."}}\n\n'
+            yield "data: [DONE]\n\n"
+    return gen
 
 def analyze_document(files):
+    def gen():
+        if files:
+            texts = [extract_text_from_file(f) for f in files]
+            text = "\n".join(texts)
+            if text:
+                analysis = ask_grok([{"role": "user", "content": f"Analyze the following document(s) for legal issues, risks, or key clauses: {text[:10000]}"}], stream=False)  # Explicitly non-stream
+                full_response = f"Analysis: {analysis}"
+                chunks = [full_response[i:i+200] for i in range(0, len(full_response), 200)]
+                for part in chunks:
+                    yield f'data: {{"chunk": {json.dumps(part)}}}\n\n'
+                yield "data: [DONE]\n\n"
+            else:
+                yield f'data: {{"chunk": "No text extracted from file."}}\n\n'
+                yield "data: [DONE]\n\n"
+        else:
+            yield f'data: {{"chunk": "No file uploaded for analysis."}}\n\n'
+            yield "data: [DONE]\n\n"
+    return gen
 
 def check_issues(files):
+    def gen():
+        if files:
+            texts = [extract_text_from_file(f) for f in files]
+            text = "\n".join(texts)
+            if text:
+                issues = ask_grok([{"role": "user", "content": f"Check for red flags, unusual clauses, or potential issues in this legal document(s) and highlight them: {text[:10000]}"}], stream=False)  # Explicitly non-stream
+                full_response = f"Highlighted Issues: {issues}"
+                chunks = [full_response[i:i+200] for i in range(0, len(full_response), 200)]
+                for part in chunks:
+                    yield f'data: {{"chunk": {json.dumps(part)}}}\n\n'
+                yield "data: [DONE]\n\n"
+            else:
+                yield f'data: {{"chunk": "No text extracted from file."}}\n\n'
+                yield "data: [DONE]\n\n"
+        else:
+            yield f'data: {{"chunk": "No file uploaded to check."}}\n\n'
+            yield "data: [DONE]\n\n"
+    return gen
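Note that all three document helpers return the inner generator function itself rather than a running generator, so the caller chooses when iteration (and thus the Grok call) starts:

```python
gen_func = check_issues(["/tmp/lease.pdf"])  # placeholder path; nothing runs yet
for sse_line in gen_func():                  # calling it starts the work
    print(sse_line, end="")
```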
+
+# Error handlers to always return JSON
+@app.errorhandler(400)
+def bad_request(error):
+    return jsonify({'error': 'Bad request'}), 400
+
+@app.errorhandler(404)
+def not_found(error):
+    return jsonify({'error': 'Not found'}), 404
+
+@app.errorhandler(405)
+def method_not_allowed(error):
+    return jsonify({'error': 'Method not allowed'}), 405
+
+@app.errorhandler(500)
+def internal_error(error):
+    return jsonify({'error': 'Internal server error'}), 500
+
+@app.errorhandler(Exception)
+def handle_exception(e):
+    logger.error(f"Unhandled exception: {str(e)}")
+    return jsonify({'error': str(e)}), 500
+
+# Flask routes
+@app.route('/')
+def index():
+    return send_from_directory('.', 'index.html')
+
+@app.route('/api/chat', methods=['POST'])
+def api_chat():
+    temp_paths = []  # Initialized here so the finally block can see it
+    def generate():
+        try:
+            # Early check for missing data
+            if 'payload' not in request.form:
+                yield f'data: {{"error": "Missing payload in request"}}\n\n'
+                yield "data: [DONE]\n\n"
+                return
+            payload = json.loads(request.form['payload'])
+            messages = payload['messages']
+            jurisdiction = payload['jurisdiction']
+            irac_mode = payload['irac_mode']
+            search_web = payload['web_search']
+            uploaded_files = request.files.getlist('file')
+            file_texts = []
+            if uploaded_files:
+                for file in uploaded_files:
+                    if file.filename:
+                        filename = secure_filename(file.filename)
+                        temp_path = os.path.join('/tmp', filename)
+                        file.save(temp_path)
+                        file_text = extract_text_from_file(temp_path)
+                        file_texts.append(file_text)
+                        temp_paths.append(temp_path)
+            file_text_combined = "\n".join(file_texts)
+            prompt = messages[-1]['content']  # For classification
+            task_type = classify_prompt(prompt)
+            if irac_mode:
+                task_type = "irac"
+            # Append file text to last user message if present
+            if file_text_combined:
+                messages[-1]['content'] += "\nAttached file content(s): " + file_text_combined[:10000]
+            if "summarize" in prompt.lower():
+                task_type = "document_analysis"
+                gen_func = summarize_document(temp_paths)
+                for chunk in gen_func():
+                    yield chunk
+            elif "analyze" in prompt.lower():
+                task_type = "document_analysis"
+                gen_func = analyze_document(temp_paths)
+                for chunk in gen_func():
+                    yield chunk
+            elif "check" in prompt.lower() or "issues" in prompt.lower() or "highlight" in prompt.lower():
+                task_type = "document_analysis"
+                gen_func = check_issues(temp_paths)
+                for chunk in gen_func():
+                    yield chunk
+            elif "generate" in prompt.lower() or "draft" in prompt.lower():
+                task_type = "document_creation"
+                for line in route_model(messages, task_type, temp_paths, search_web, jurisdiction):
+                    yield line
+            else:
+                for line in route_model(messages, task_type, temp_paths, search_web, jurisdiction):
+                    yield line
+            logger.info("Grok response streamed.")
+        except Exception as e:
+            logger.error(f"Error in /api/chat: {str(e)}")
+            yield f'data: {{"error": "{str(e)}"}}\n\n'
+            yield "data: [DONE]\n\n"
+        finally:
+            # Clean up temp files
+            for temp_path in temp_paths:
+                try:
+                    os.remove(temp_path)
+                except Exception as cleanup_e:
+                    logger.error(f"Cleanup error: {str(cleanup_e)}")
+    return Response(stream_with_context(generate()), mimetype='text/event-stream')
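For reference, a minimal Python client for this route. It sends the multipart form the handler expects, a "payload" form field of JSON plus optional "file" parts, and reads back the SSE lines; the base URL and file path are placeholders:

```python
import json
import requests

payload = {
    "messages": [{"role": "user", "content": "Draft a demand letter."}],
    "jurisdiction": "KY",
    "irac_mode": False,
    "web_search": False,
}
with requests.post(
    "http://localhost:7860/api/chat",         # placeholder base URL
    data={"payload": json.dumps(payload)},    # JSON travels in a form field, not the body
    files={"file": open("lease.pdf", "rb")},  # optional upload
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or line == "data: [DONE]":
            continue
        event = json.loads(line[len("data: "):])
        print(event.get("chunk", event), end="")
```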
+
+@app.route('/download/<filename>', methods=['GET'])
+def download(filename):
+    return send_file(os.path.join('/tmp', filename), as_attachment=True)
+
+@app.route('/health', methods=['GET'])
+def health():
+    return "OK", 200
+
+if __name__ == '__main__':
+    logger.info("✅ All init complete. Starting Flask app...")
+    print("Flask app starting...")  # Fallback print
+    app.run(host='0.0.0.0', port=7860)