import nest_asyncio
import streamlit as st
import os
from flashrank import Ranker, RerankRequest
from qdrant_client import QdrantClient
from llama_index.llms.groq import Groq
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import PyPDF2
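# allow nested asyncio event loops so llama-index components can run inside Streamlit's script runner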
nest_asyncio.apply()
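# API tokens are read from Streamlit secrets; expected keys in .streamlit/secrets.toml (placeholder values):
#   HF_TOKEN = "..."
#   groq_token = "..."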
os.environ["HF_TOKEN"] = st.secrets["HF_TOKEN"]
groq_token = st.secrets["groq_token"]
st.set_page_config(layout="wide")
# default LlamaIndex LLM and embedding model selection
@st.cache_resource(show_spinner=False)
def llamaindex_default():
    Settings.llm = Groq(model="llama-3.1-8b-instant", api_key=groq_token)
    Settings.embed_model = HuggingFaceEmbedding(
        model_name="law-ai/InLegalBERT", trust_remote_code=True
    )
llamaindex_default()
# set up local qdrant client and load the persisted vector index
@st.cache_resource(show_spinner=False)
def load_index():
    qdrant_client = QdrantClient(
        path="."
    )
    vector_store = QdrantVectorStore(
        client=qdrant_client, collection_name="legal_v1"
    )
    return VectorStoreIndex.from_vector_store(vector_store=vector_store)
index = load_index()
# reranker selection in the sidebar
with st.sidebar:
    selected_reranker = st.selectbox(
        "Select a reranker",
        ("default", "ms-marco-MiniLM-L-12-v2", "rank-T5-flan")
    )
    if selected_reranker == "default":
        ranker = Ranker()
    else:
        ranker = Ranker(model_name=selected_reranker, cache_dir=".")
    # Calculate individual weightages with sidebar slider
    dense_weightage = st.slider("Dense Weightage", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
    sparse_weightage = 1 - dense_weightage
    st.write("dense weight: ", dense_weightage)
    st.write("sparse weight: ", sparse_weightage)
    num_k = st.number_input(
        "Enter k",
        min_value=1,
        max_value=10,
        value=10
    )
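# hybrid retriever: dense vector search and sparse BM25 results fused with relative-score weighting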
@st.cache_resource(show_spinner=False)
def load_retriever(num_k, dense_weightage, sparse_weightage):
    # sidebar values are passed as arguments so the cached retriever is rebuilt when they change
    dense_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=num_k
    )
    sparse_retriever = BM25Retriever.from_persist_dir("./sparse_retriever")
    sparse_retriever.similarity_top_k = num_k
    retriever = QueryFusionRetriever(
        [
            dense_retriever,
            sparse_retriever,
        ],
        num_queries=1,
        use_async=False,
        retriever_weights=[dense_weightage, sparse_weightage],
        similarity_top_k=num_k,
        mode="relative_score",
        verbose=True,
    )
    return retriever
retriever = load_retriever(num_k, dense_weightage, sparse_weightage)
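# extract the full text of a source PDF so the LLM can summarize it on demand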
def extract_pdf_content(pdf_file_path):
    with open(pdf_file_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        text = ""
        for page_num in range(len(pdf_reader.pages)):
            page = pdf_reader.pages[page_num]
            text += page.extract_text()
        return text
# prompt template for summarization
template = """
Please summarize the following legal document and provide the summary in the specified format. The output should directly follow the format without any introductory text.
**Document:**
{document_content}
**Format:**
**Case:** [Case Number]
**Petitioner:** [Petitioner's Name]
**Respondent:** [Respondent's Name]
**Judge:** [Judge's Name]
**Order Date:** [Order Date]
**Summary:**
- **Background:** [Brief description of the case background]
- **Allegations:** [Summary of the allegations made in the case]
- **Investigation:** [Key findings from the investigation]
- **Court's Decision:** [Summary of the court's decision and any conditions imposed]
"""
st.title("Legal Documents Hybrid+Reranker Search")
query = st.text_input("Search through documents by keyword", value="")
search_btn = st.button("Search")
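# on search: retrieve fused candidates, rerank them with FlashRank, then summarize each source PDF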
if search_btn and query:
    nodes = retriever.retrieve(query)
    passages = []
    for node in nodes:
        passage = {
            "id": node.node_id,
            "text": node.text,
            "meta": node.metadata
        }
        passages.append(passage)
    rerankrequest = RerankRequest(query=query, passages=passages)
    results = ranker.rerank(rerankrequest)
    for node in results:
        st.write("File Name: ", node["meta"].get("file_name"))
        st.write("reranking score: ", node["score"])
        st.write("node id: ", node["id"])
        with st.expander("See Summary"):
            text = extract_pdf_content("./documents/" + node["meta"].get("file_name"))
            formatted_template = template.format(document_content=text)
            summary = Settings.llm.complete(formatted_template)
            st.markdown(summary.text)
        st.write("---")