Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,11 +9,10 @@ from googleapiclient.discovery import build
|
|
| 9 |
from openai import OpenAI
|
| 10 |
import re
|
| 11 |
from datasets import load_dataset, Dataset
|
| 12 |
-
from
|
| 13 |
import torch
|
| 14 |
import numpy as np
|
| 15 |
import faiss
|
| 16 |
-
import shutil
|
| 17 |
import os
|
| 18 |
os.environ["HF_HOME"] = "/data/.huggingface"
|
| 19 |
|
|
@@ -58,11 +57,8 @@ from huggingface_hub import hf_hub_download
|
|
| 58 |
index_path = hf_hub_download(repo_id="brandonmusic/VerdictAI", filename="knn.index", repo_type="dataset")
|
| 59 |
faiss_index = faiss.read_index(index_path)
|
| 60 |
|
| 61 |
-
# Load
|
| 62 |
-
|
| 63 |
-
dpr_tokenizer = AutoTokenizer.from_pretrained(dpr_model_name)
|
| 64 |
-
dpr_model = AutoModel.from_pretrained(dpr_model_name)
|
| 65 |
-
dpr_model.eval()
|
| 66 |
|
| 67 |
# Simulated case law data (replace with Pyserini integration)
|
| 68 |
CASE_LAW_DB = {
|
|
@@ -134,6 +130,19 @@ STATES = {
|
|
| 134 |
"Other": "Other States"
|
| 135 |
}
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def route_model(prompt, task_type, files=None, search_web=False, jurisdiction="KY"):
|
| 138 |
logger.info(f"Routing prompt: {prompt}, Task: {task_type}, Web Search: {search_web}, Jurisdiction: {jurisdiction}")
|
| 139 |
|
|
@@ -615,6 +624,7 @@ with gr.Blocks(css=css, theme=theme, title="VerdictAI - Legal Assistant") as app
|
|
| 615 |
with gr.Accordion("📁 My Files / Cases", open=False, elem_classes=["menu-accordion"]) as files_accordion:
|
| 616 |
file_list = gr.State([])
|
| 617 |
uploaded_files = gr.File(file_count="multiple", file_types=[".pdf", "image", "text"], label="📎 Upload Files")
|
|
|
|
| 618 |
def update_file_list(files):
|
| 619 |
if files:
|
| 620 |
file_names = "\n".join(["- " + os.path.basename(f) for f in files])
|
|
@@ -624,7 +634,6 @@ with gr.Blocks(css=css, theme=theme, title="VerdictAI - Legal Assistant") as app
|
|
| 624 |
tag_input = gr.Textbox(placeholder="Tag by case name/client", label="Tag")
|
| 625 |
summarize_btn = gr.Button("AI Summarize", elem_classes=["sidebar-btn"])
|
| 626 |
summarize_btn.click(fn=summarize_document, inputs=[uploaded_files], outputs=summary_output)
|
| 627 |
-
uploaded_files_list = gr.Markdown("Uploaded Files:")
|
| 628 |
summary_output = gr.Textbox(label="Summary Output", show_label=False)
|
| 629 |
|
| 630 |
with gr.Accordion("💬 Saved Chats", open=False, elem_classes=["menu-accordion"]) as chat_accordion:
|
|
|
|
| 9 |
from openai import OpenAI
|
| 10 |
import re
|
| 11 |
from datasets import load_dataset, Dataset
|
| 12 |
+
from sentence_transformers import SentenceTransformer
|
| 13 |
import torch
|
| 14 |
import numpy as np
|
| 15 |
import faiss
|
|
|
|
| 16 |
import os
|
| 17 |
os.environ["HF_HOME"] = "/data/.huggingface"
|
| 18 |
|
|
|
|
| 57 |
index_path = hf_hub_download(repo_id="brandonmusic/VerdictAI", filename="knn.index", repo_type="dataset")
|
| 58 |
faiss_index = faiss.read_index(index_path)
|
| 59 |
|
| 60 |
+
# Load the correct encoder for embeddings (BAAI/bge-base-en-v1.5)
|
| 61 |
+
encoder_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# Simulated case law data (replace with Pyserini integration)
|
| 64 |
CASE_LAW_DB = {
|
|
|
|
| 130 |
"Other": "Other States"
|
| 131 |
}
|
| 132 |
|
| 133 |
+
def retrieve_legal_context(query, k=5):
|
| 134 |
+
# Encode query to vector
|
| 135 |
+
query_embedding = encoder_model.encode(query, normalize_embeddings=True)
|
| 136 |
+
query_embedding = np.array([query_embedding]).astype('float32') # FAISS expects float32
|
| 137 |
+
# Search FAISS index
|
| 138 |
+
distances, indices = faiss_index.search(query_embedding, k)
|
| 139 |
+
# Fetch texts from dataset (assuming 'text' column; adjust if needed)
|
| 140 |
+
contexts = []
|
| 141 |
+
for idx in indices[0]:
|
| 142 |
+
if idx >= 0 and idx < len(cap_dataset):
|
| 143 |
+
contexts.append(cap_dataset[int(idx)]['text'])
|
| 144 |
+
return "\n\n".join(contexts) # Or format as needed
|
| 145 |
+
|
| 146 |
def route_model(prompt, task_type, files=None, search_web=False, jurisdiction="KY"):
|
| 147 |
logger.info(f"Routing prompt: {prompt}, Task: {task_type}, Web Search: {search_web}, Jurisdiction: {jurisdiction}")
|
| 148 |
|
|
|
|
| 624 |
with gr.Accordion("📁 My Files / Cases", open=False, elem_classes=["menu-accordion"]) as files_accordion:
|
| 625 |
file_list = gr.State([])
|
| 626 |
uploaded_files = gr.File(file_count="multiple", file_types=[".pdf", "image", "text"], label="📎 Upload Files")
|
| 627 |
+
uploaded_files_list = gr.Markdown("Uploaded Files:")
|
| 628 |
def update_file_list(files):
|
| 629 |
if files:
|
| 630 |
file_names = "\n".join(["- " + os.path.basename(f) for f in files])
|
|
|
|
| 634 |
tag_input = gr.Textbox(placeholder="Tag by case name/client", label="Tag")
|
| 635 |
summarize_btn = gr.Button("AI Summarize", elem_classes=["sidebar-btn"])
|
| 636 |
summarize_btn.click(fn=summarize_document, inputs=[uploaded_files], outputs=summary_output)
|
|
|
|
| 637 |
summary_output = gr.Textbox(label="Summary Output", show_label=False)
|
| 638 |
|
| 639 |
with gr.Accordion("💬 Saved Chats", open=False, elem_classes=["menu-accordion"]) as chat_accordion:
|