brandonmusic committed (verified)
Commit a4e3c6d · Parent(s): 34b551a

Update retrieval.py

Files changed (1)
  1. retrieval.py +35 -39
retrieval.py CHANGED
@@ -1,9 +1,3 @@
-
-This updated `app.py` script includes the jurisdiction mapping in `route_model` (e.g., defaults to "KY" if not specified, maps to court codes like 'ky kyctapp'). It's fully copy-pastable—replace your existing file.
-
-### Updated retrieval.py Script
-
-```python
 # retrieval.py
 # Uncommented all sections for full functionality.
 # Removed duplicated Flask code at the end (copy-paste error).
@@ -13,12 +7,14 @@ This updated `app.py` script includes the jurisdiction mapping in `route_model`
 # Integrated google_search.
 import os
 import logging
-import requests # Lightweight, keep at top
-import pickle # Lightweight
+import requests # Lightweight, keep at top
+import pickle # Lightweight
 import shutil
 from huggingface_hub import hf_hub_download, snapshot_download
-from openai import OpenAI # Client init is fast, but usage in functions
-import time # For sleep after download
+from openai import OpenAI # Client init is fast, but usage in functions
+import time # For sleep after download
+import re
+import datetime
 # Logging setup (lightweight)
 logger = logging.getLogger("retrieval")
 logging.basicConfig(level=logging.INFO)
@@ -26,23 +22,23 @@ logging.basicConfig(level=logging.INFO)
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 openai_client = OpenAI(api_key=OPENAI_API_KEY)
 GOOGLE_CUSTOM_SEARCH_API_KEY = os.environ.get("GOOGLE_CUSTOM_SEARCH_API_KEY", "Missing")
-GOOGLE_SEARCH_API = os.environ.get("GOOGLE_SEARCH_API", "Missing") # CSE ID
+GOOGLE_SEARCH_API = os.environ.get("GOOGLE_SEARCH_API", "Missing") # CSE ID
 hf_token = os.environ.get("HF_TOKEN", "")
 COURT_LISTENER_API_KEY = os.environ.get("Court_Listener_API", "Missing") # Updated to match HF secret name
 # Lazy placeholders
 centroid_vectors = None
 encoder = None
-municipal_encoder = None # Separate for potential dim mismatch
+municipal_encoder = None # Separate for potential dim mismatch
 _cluster_cache = {}
 municipal_faiss_index = None
-cap_faiss_index = None # New for CAP FAISS
+cap_faiss_index = None # New for CAP FAISS
 municipal_metadata = None
 municipal_texts = None
 bm25_municipal = None
 # Lazy-load CAP dataset
 def get_cap_dataset():
     if not hasattr(get_cap_dataset, 'dataset') or get_cap_dataset.dataset is None:
-        from datasets import load_from_disk # Lazy import
+        from datasets import load_from_disk # Lazy import
         LOCAL_PATH = "/data/cap_dataset"
         if os.path.exists(os.path.join(LOCAL_PATH, 'dataset_info.json')):
             try:
@@ -72,7 +68,7 @@ def load_encoder():
     global encoder
     if encoder is not None:
         return
-    from sentence_transformers import SentenceTransformer # Lazy import
+    from sentence_transformers import SentenceTransformer # Lazy import
     encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
     logger.info("🚀 Lazy-loaded CAP Encoder: SentenceTransformer (all-mpnet-base-v2 for 768 dim match)")
     logger.info(f"CAP encoder dimension: {encoder.get_sentence_embedding_dimension()}")
@@ -80,7 +76,7 @@ def load_municipal_encoder():
     global municipal_encoder
     if municipal_encoder is not None:
         return
-    from sentence_transformers import SentenceTransformer # Lazy import
+    from sentence_transformers import SentenceTransformer # Lazy import
     municipal_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
     logger.info("🚀 Lazy-loaded Municipal Encoder: SentenceTransformer (all-MiniLM-L6-v2 for 384 dim match)")
     logger.info(f"Municipal encoder dimension: {municipal_encoder.get_sentence_embedding_dimension()}")
@@ -88,7 +84,7 @@ def load_cap_faiss_index():
     global cap_faiss_index
     if cap_faiss_index is not None:
         return
-    import faiss # Lazy import
+    import faiss # Lazy import
     cap_index_path = "/data/knn.index"
     if not os.path.exists(cap_index_path):
         try:
@@ -96,7 +92,7 @@ def load_cap_faiss_index():
             logger.info("✅ Downloaded missing CAP FAISS index from HF.")
         except Exception as e:
             logger.error(f"❌ Failed to download CAP FAISS index: {str(e)}. CAP semantic search disabled.")
-            cap_faiss_index = "loaded" # Marker to avoid reload
+            cap_faiss_index = "loaded" # Marker to avoid reload
             return
     try:
         cap_faiss_index = faiss.read_index(cap_index_path)
@@ -104,12 +100,12 @@ def load_cap_faiss_index():
         logger.info(f"CAP FAISS index dimension: {cap_faiss_index.d}")
     except Exception as e:
         logger.error(f"❌ Failed to load CAP FAISS index: {str(e)}. CAP semantic search disabled.")
-        cap_faiss_index = "loaded" # Marker
+        cap_faiss_index = "loaded" # Marker
 def load_municipal_faiss_index():
     global municipal_faiss_index
     if municipal_faiss_index is not None:
         return
-    import faiss # Lazy import
+    import faiss # Lazy import
     municipal_index_path = "/data/municipal.index"
     if os.path.exists(municipal_index_path):
         municipal_faiss_index = faiss.read_index(municipal_index_path)
@@ -117,7 +113,7 @@ def load_municipal_faiss_index():
         logger.info(f"Municipal FAISS index dimension: {municipal_faiss_index.d}")
     else:
         logger.error("municipal.index not found in /data. Hybrid search for municipal data disabled.")
-        municipal_faiss_index = "loaded" # Marker to avoid reload
+        municipal_faiss_index = "loaded" # Marker to avoid reload
 def load_municipal_metadata():
     global municipal_metadata
     if municipal_metadata is not None:
@@ -146,14 +142,14 @@ def load_bm25_municipal():
     global bm25_municipal
     if bm25_municipal is not None:
         return
-    from rank_bm25 import BM25Okapi # Lazy import
+    from rank_bm25 import BM25Okapi # Lazy import
     bm25_municipal_path = "/data/bm25_municipal.pkl"
     if os.path.exists(bm25_municipal_path):
         with open(bm25_municipal_path, 'rb') as f:
             bm25_municipal = pickle.load(f)
         logger.info("✅ Lazy-loaded cached BM25 for municipal hybrid search.")
     else:
-        load_municipal_texts() # Ensure texts loaded
+        load_municipal_texts() # Ensure texts loaded
         if not municipal_texts:
             logger.error("Cannot build BM25 index because municipal texts are not loaded.")
             bm25_municipal = "build_failed"
@@ -164,18 +160,18 @@ def load_bm25_municipal():
             pickle.dump(bm25_municipal, f)
         logger.info("✅ Built and cached BM25 for municipal hybrid search.")
 def semantic_search(query, top_k=5, min_score=0.1):
-    import numpy as np # Lazy import
-    from sklearn.feature_extraction.text import TfidfVectorizer # Lazy import
-    from sklearn.metrics.pairwise import cosine_similarity # Lazy import
+    import numpy as np # Lazy import
+    from sklearn.feature_extraction.text import TfidfVectorizer # Lazy import
+    from sklearn.metrics.pairwise import cosine_similarity # Lazy import
     logger.info(f"Search query sent to FAISS (CAP): {query}")
     load_cap_faiss_index()
-    if cap_faiss_index == "loaded": # Marker for failed load
+    if cap_faiss_index == "loaded": # Marker for failed load
         logger.warning("CAP FAISS index not available. Returning empty results.")
         return []
     load_encoder()
     query_vec = encoder.encode(query, normalize_embeddings=True)
     query_vec = np.array(query_vec).astype('float32').reshape(1, -1)
-    import faiss # Ensure imported
+    import faiss # Ensure imported
     try:
         if query_vec.shape[1] != cap_faiss_index.d:
             raise AssertionError(f"Dimension mismatch: query {query_vec.shape[1]} != index {cap_faiss_index.d}")
@@ -218,7 +214,7 @@ def semantic_search(query, top_k=5, min_score=0.1):
     logger.info(f"FAISS (CAP) returned {len(results)} docs")
     return [{k: v for k, v in r.items() if k != 'score'} for r in results]
 def municipal_search(query, top_k=5, min_score=0.1):
-    import numpy as np # Lazy import
+    import numpy as np # Lazy import
     load_municipal_faiss_index()
     load_municipal_encoder()
     load_bm25_municipal()
@@ -279,23 +275,23 @@ def municipal_search(query, top_k=5, min_score=0.1):
     return [{k: v for k, v in r.items() if k != 'score'} for r in results[:top_k]]
 def retrieve_context(original_prompt, task_type, jurisdiction="ky"):
     query = query_rewrite(original_prompt, task_type)
-
+
     cap_results = semantic_search(query)
     municipal_results = municipal_search(query)
-
+
     combined_results = cap_results + municipal_results
-
+
     if not combined_results:
         logger.warning(f"No context found for query: {query} (task: {task_type}) — attempting web fallback.")
-        fallback_query = f"{query} site:law.cornell.edu OR site:justia.com OR site:findlaw.com"
+        fallback_query = f"{query} site:law.cornell.edu OR site:justia.com OR site:findlaw.com OR site:findlaw.com"
         web_data = google_search(fallback_query, GOOGLE_CUSTOM_SEARCH_API_KEY, GOOGLE_SEARCH_API)
         if web_data != "No search results found.":
             combined_results = [{"source": "Web", "name": "Web Fallback", "citation": "Various Sources", "snippet": web_data[:700]}]
-
+
     # Added: Call CourtListener for case_law or irac tasks and append to combined_results
     if task_type in ["case_law", "irac"] and COURT_LISTENER_API_KEY != "Missing":
         logger.info("Calling CourtListener API...")
-        courtlistener_results = search_courtlistener(query, jurisdiction.lower(), '2021-01-01', '2025-08-11')
+        courtlistener_results = search_courtlistener(query, jurisdiction.lower(), '2021-01-01', datetime.datetime.today().date().isoformat())
         if courtlistener_results and 'results' in courtlistener_results:
             logger.info(f"CourtListener returned {len(courtlistener_results['results'])} results")
             for result in courtlistener_results['results']:
@@ -323,7 +319,7 @@ def query_rewrite(original_prompt, task_type):
             temperature=0.3,
             max_tokens=50
         )
-        rewritten = response.choices[0].message.content.strip().replace('"', '') # Stripped quotes per Gemini
+        rewritten = response.choices[0].message.content.strip().replace('"', '') # Stripped quotes per Gemini
         logger.info(f"Original prompt: {original_prompt[:100]}... -> Rewritten query: {rewritten}")
         return rewritten
     except Exception as e:
@@ -335,7 +331,7 @@ def google_search(query, GOOGLE_CUSTOM_SEARCH_API_KEY, GOOGLE_SEARCH_API):
         return "Google Custom Search API key not set."
     if GOOGLE_SEARCH_API == "Missing":
         return "Google CSE ID not set."
-    from googleapiclient.discovery import build # Lazy import
+    from googleapiclient.discovery import build # Lazy import
     service = build("customsearch", "v1", developerKey=GOOGLE_CUSTOM_SEARCH_API_KEY, cache_discovery=False)
     res = service.cse().list(q=query, cx=GOOGLE_SEARCH_API).execute()
     if "items" in res:
@@ -357,7 +353,7 @@ def ground_statutes(response, jurisdiction, GOOGLE_CUSTOM_SEARCH_API_KEY, GOOGLE
     # In practice, parse response for statute mentions, search, and replace/inject quotes
     try:
         # Example: Find statute mentions and ground
-        statute_mentions = re.findall(r'KRS \d+\.\d+', response) # Simple regex for KRS
+        statute_mentions = re.findall(r'KRS \d+\.\d+', response) # Simple regex for KRS
         if statute_mentions:
             for stat in statute_mentions:
                 search_result = google_search(f"{stat} {jurisdiction} statute text", GOOGLE_CUSTOM_SEARCH_API_KEY, GOOGLE_SEARCH_API)
@@ -368,7 +364,7 @@ def ground_statutes(response, jurisdiction, GOOGLE_CUSTOM_SEARCH_API_KEY, GOOGLE
         logger.error(f"Grounding error: {str(e)}")
         return response
 # New function for CourtListener search (added)
-def search_courtlistener(query, jurisdiction='ky', date_min='2021-01-01', date_max='2025-08-11'):
+def search_courtlistener(query, jurisdiction='ky', date_min='2021-01-01', date_max=datetime.datetime.today().date().isoformat()):
     """
     Searches CourtListener for cases matching the query.
     Returns JSON data for RAG processing.
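A note on the lazy loaders above: they reuse the string "loaded" as a failure marker, so `semantic_search` tests `cap_faiss_index == "loaded"` to detect a failed load. That works, since a real FAISS index never compares equal to a string, but a dedicated sentinel keeps "not yet loaded" (None) and "load failed" visibly distinct. A minimal sketch of that pattern (a hypothetical refactor, not what this commit does):

```python
import faiss  # imported eagerly here for brevity; the commit imports it lazily

_LOAD_FAILED = object()  # unique sentinel instead of the string "loaded"

cap_faiss_index = None  # None = not yet loaded; _LOAD_FAILED = gave up

def load_cap_faiss_index():
    global cap_faiss_index
    if cap_faiss_index is not None:
        return  # already loaded, or already failed once
    try:
        cap_faiss_index = faiss.read_index("/data/knn.index")
    except Exception:
        cap_faiss_index = _LOAD_FAILED  # cache the failure, skip retries

def cap_index_available():
    return cap_faiss_index is not None and cap_faiss_index is not _LOAD_FAILED
```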
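One caveat with the new `date_max` default in `search_courtlistener` (last hunk): Python evaluates default arguments once, at import time, so `datetime.datetime.today().date().isoformat()` freezes to the day the module was loaded. The call in `retrieve_context` passes the date explicitly and therefore recomputes it per request, but any caller relying on the default would see a stale date in a long-running Space. The usual fix is a `None` default resolved at call time; a minimal sketch, assuming the rest of the function stays unchanged:

```python
import datetime

def search_courtlistener(query, jurisdiction='ky',
                         date_min='2021-01-01', date_max=None):
    if date_max is None:
        # Resolved on every call; an isoformat() expression in the signature
        # would be evaluated only once, when the module is imported.
        date_max = datetime.date.today().isoformat()
    ...  # rest of the function body unchanged
```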
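The body of `search_courtlistener` is cut off in this view. For orientation only, here is a rough sketch of how such a lookup is typically wired to CourtListener's REST search API; the endpoint, the parameter names (`court`, `filed_after`, `filed_before`), and the token auth scheme are assumptions to verify against the real file, not details taken from this commit:

```python
import os
import requests

COURT_LISTENER_API_KEY = os.environ.get("Court_Listener_API", "Missing")

def search_courtlistener_sketch(query, jurisdiction='ky',
                                date_min='2021-01-01', date_max='2025-08-11'):
    # Assumed endpoint and filters; check the actual implementation.
    resp = requests.get(
        "https://www.courtlistener.com/api/rest/v3/search/",
        params={
            "q": query,
            "court": jurisdiction,       # e.g. 'ky' or 'kyctapp'
            "filed_after": date_min,
            "filed_before": date_max,
        },
        headers={"Authorization": f"Token {COURT_LISTENER_API_KEY}"},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()  # retrieve_context expects a dict with a 'results' key
```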
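Finally, a quick usage sketch of the retrieval entry point, mirroring how `app.py` presumably drives it; the prompt text is illustrative and the shape of the return value is not shown in this diff:

```python
from retrieval import retrieve_context

# task_type "case_law" (or "irac") also triggers the CourtListener branch.
context = retrieve_context(
    "Kentucky security deposit return obligations for landlords",
    task_type="case_law",
    jurisdiction="ky",
)
print(context)
```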