brandonmusic committed on
Commit
79d76ff
·
verified ·
1 Parent(s): cb5ed63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -9,11 +9,10 @@ from googleapiclient.discovery import build
9
  from openai import OpenAI
10
  import re
11
  from datasets import load_dataset, Dataset
12
- from transformers import AutoTokenizer, AutoModel
13
  import torch
14
  import numpy as np
15
  import faiss
16
- import shutil
17
  import os
18
  os.environ["HF_HOME"] = "/data/.huggingface"
19
 
@@ -58,11 +57,8 @@ from huggingface_hub import hf_hub_download
58
  index_path = hf_hub_download(repo_id="brandonmusic/VerdictAI", filename="knn.index", repo_type="dataset")
59
  faiss_index = faiss.read_index(index_path)
60
 
61
- # Load LegalBERT-DPR encoder
62
- dpr_model_name = "jhu-clsp/LegalBERT-DPR-CLERC-ft"
63
- dpr_tokenizer = AutoTokenizer.from_pretrained(dpr_model_name)
64
- dpr_model = AutoModel.from_pretrained(dpr_model_name)
65
- dpr_model.eval()
66
 
67
  # Simulated case law data (replace with Pyserini integration)
68
  CASE_LAW_DB = {
@@ -134,6 +130,19 @@ STATES = {
134
  "Other": "Other States"
135
  }
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def route_model(prompt, task_type, files=None, search_web=False, jurisdiction="KY"):
138
  logger.info(f"Routing prompt: {prompt}, Task: {task_type}, Web Search: {search_web}, Jurisdiction: {jurisdiction}")
139
 
@@ -615,6 +624,7 @@ with gr.Blocks(css=css, theme=theme, title="VerdictAI - Legal Assistant") as app
615
  with gr.Accordion("📁 My Files / Cases", open=False, elem_classes=["menu-accordion"]) as files_accordion:
616
  file_list = gr.State([])
617
  uploaded_files = gr.File(file_count="multiple", file_types=[".pdf", "image", "text"], label="📎 Upload Files")
 
618
  def update_file_list(files):
619
  if files:
620
  file_names = "\n".join(["- " + os.path.basename(f) for f in files])
@@ -624,7 +634,6 @@ with gr.Blocks(css=css, theme=theme, title="VerdictAI - Legal Assistant") as app
624
  tag_input = gr.Textbox(placeholder="Tag by case name/client", label="Tag")
625
  summarize_btn = gr.Button("AI Summarize", elem_classes=["sidebar-btn"])
626
  summarize_btn.click(fn=summarize_document, inputs=[uploaded_files], outputs=summary_output)
627
- uploaded_files_list = gr.Markdown("Uploaded Files:")
628
  summary_output = gr.Textbox(label="Summary Output", show_label=False)
629
 
630
  with gr.Accordion("💬 Saved Chats", open=False, elem_classes=["menu-accordion"]) as chat_accordion:
 
9
  from openai import OpenAI
10
  import re
11
  from datasets import load_dataset, Dataset
12
+ from sentence_transformers import SentenceTransformer
13
  import torch
14
  import numpy as np
15
  import faiss
 
16
  import os
17
  os.environ["HF_HOME"] = "/data/.huggingface"
18
 
 
57
  index_path = hf_hub_download(repo_id="brandonmusic/VerdictAI", filename="knn.index", repo_type="dataset")
58
  faiss_index = faiss.read_index(index_path)
59
 
60
+ # Load the correct encoder for embeddings (BAAI/bge-base-en-v1.5)
61
+ encoder_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
 
 
 
62
 
63
  # Simulated case law data (replace with Pyserini integration)
64
  CASE_LAW_DB = {
 
130
  "Other": "Other States"
131
  }
132
 
133
def retrieve_legal_context(query, k=5):
    """Return the texts of the nearest indexed passages for *query*.

    The query is embedded with the module-level ``encoder_model``, searched
    against the module-level ``faiss_index``, and the matching rows are read
    from ``cap_dataset``.

    Args:
        query: Natural-language query string.
        k: Maximum number of neighbours to retrieve.

    Returns:
        The ``'text'`` fields of the hits joined by blank lines, or an empty
        string when the query is empty/blank, ``k`` is not positive, or no
        valid hit is found.
    """
    # Nothing meaningful to embed or retrieve: skip the encoder and the index
    # entirely instead of returning neighbours of an empty string.
    if k <= 0 or not query or not query.strip():
        return ""

    # Encode query to a (1, dim) float32 matrix, the layout FAISS expects.
    query_embedding = encoder_model.encode(query, normalize_embeddings=True)
    query_embedding = np.ascontiguousarray(
        query_embedding, dtype=np.float32
    ).reshape(1, -1)

    # Search FAISS index; the distances are not used here.
    _, indices = faiss_index.search(query_embedding, k)

    # Fetch texts from dataset (assuming 'text' column; adjust if needed).
    # Out-of-range ids (including negative ones) are skipped, as in the
    # original bounds check.
    dataset_size = len(cap_dataset)
    contexts = [
        cap_dataset[int(idx)]['text']
        for idx in indices[0]
        if 0 <= idx < dataset_size
    ]
    return "\n\n".join(contexts)
145
+
146
  def route_model(prompt, task_type, files=None, search_web=False, jurisdiction="KY"):
147
  logger.info(f"Routing prompt: {prompt}, Task: {task_type}, Web Search: {search_web}, Jurisdiction: {jurisdiction}")
148
 
 
624
  with gr.Accordion("📁 My Files / Cases", open=False, elem_classes=["menu-accordion"]) as files_accordion:
625
  file_list = gr.State([])
626
  uploaded_files = gr.File(file_count="multiple", file_types=[".pdf", "image", "text"], label="📎 Upload Files")
627
+ uploaded_files_list = gr.Markdown("Uploaded Files:")
628
  def update_file_list(files):
629
  if files:
630
  file_names = "\n".join(["- " + os.path.basename(f) for f in files])
 
634
  tag_input = gr.Textbox(placeholder="Tag by case name/client", label="Tag")
635
  summarize_btn = gr.Button("AI Summarize", elem_classes=["sidebar-btn"])
636
  summarize_btn.click(fn=summarize_document, inputs=[uploaded_files], outputs=summary_output)
 
637
  summary_output = gr.Textbox(label="Summary Output", show_label=False)
638
 
639
  with gr.Accordion("💬 Saved Chats", open=False, elem_classes=["menu-accordion"]) as chat_accordion: