File size: 4,827 Bytes
84a1d68
7d84c8a
 
 
c3a7290
1a5d2d1
7d84c8a
 
 
 
 
1a5d2d1
 
33ba695
a327243
85a8211
33ba695
 
 
 
1a5d2d1
7d84c8a
157dba7
 
33ba695
157dba7
 
 
c9d41c5
1a5d2d1
7d84c8a
157dba7
 
 
 
 
 
 
 
 
d37f862
ea66760
 
0633d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56b2c14
33ba695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96c5163
33ba695
 
 
 
 
238efca
33ba695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d84c8a
54073f5
d37f862
7d84c8a
e316790
a327243
499e542
7d84c8a
 
 
edcbe37
7d84c8a
 
 
 
 
 
 
 
c3a7290
7d84c8a
a0479e0
 
 
33ba695
56b2c14
33ba695
 
4dc9e48
7d84c8a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import nest_asyncio
import streamlit as st
import os
from flashrank import Ranker, RerankRequest
from qdrant_client import QdrantClient
from llama_index.llms.groq import Groq
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import PyPDF2
# Allow re-entrant event loops: Streamlit already runs a loop, and llama-index
# may call asyncio.run() internally, which would otherwise raise.
nest_asyncio.apply()
# Credentials come from Streamlit secrets (.streamlit/secrets.toml);
# HF_TOKEN is exported so huggingface_hub can download the embedding model.
os.environ["HF_TOKEN"] = st.secrets["HF_TOKEN"]
groq_token = st.secrets["groq_token"]
st.set_page_config(
    layout="wide"
)

# Configure the process-wide LlamaIndex defaults (LLM + embedding model).
@st.cache_resource(show_spinner=False)
def llamaindex_default():
    """Install the default LLM and embedding model on llama-index Settings.

    Cached with st.cache_resource so the Groq client and the HuggingFace
    embedding model are constructed only once per Streamlit session.
    """
    llm = Groq(model="llama-3.1-8b-instant", api_key=groq_token)
    embed_model = HuggingFaceEmbedding(
        model_name="law-ai/InLegalBERT", trust_remote_code=True
    )
    Settings.llm = llm
    Settings.embed_model = embed_model
llamaindex_default()

# Open the on-disk Qdrant collection and expose it as a llama-index index.
@st.cache_resource(show_spinner=False)
def load_index():
    """Return a VectorStoreIndex backed by the local "legal_v1" collection.

    Cached so the embedded Qdrant storage (path=".") is opened only once —
    local-mode Qdrant does not allow concurrent clients on the same path.
    """
    store = QdrantVectorStore(
        client=QdrantClient(path="."),
        collection_name="legal_v1",
    )
    return VectorStoreIndex.from_vector_store(vector_store=store)

index = load_index()

# Sidebar controls: reranker model, dense/sparse fusion weights, result count.
with st.sidebar:
    reranker_choice = st.selectbox(
        "Select a reranker",
        ("default", "ms-marco-MiniLM-L-12-v2", "rank-T5-flan"),
    )

    # FlashRank uses its built-in model when no model name is given.
    ranker = (
        Ranker()
        if reranker_choice == "default"
        else Ranker(model_name=reranker_choice, cache_dir=".")
    )

    # The two fusion weights are complementary and always sum to 1.0.
    dense_weightage = st.slider("Dense Weightage", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
    sparse_weightage = 1 - dense_weightage
    st.write("dense weight: ",dense_weightage)
    st.write("sparse weight: ",sparse_weightage)
    num_k = st.number_input(
        "Enter k",
        min_value=1,
        max_value=10,
        value=10
    )

# The sidebar values are passed as arguments so they become part of the
# st.cache_resource cache key: with a zero-argument cached function the first
# retriever would be cached forever and later slider/number-input changes
# would be silently ignored.
@st.cache_resource(show_spinner=False)
def load_retriver(k, dense_w, sparse_w):
    """Build a fusion retriever combining dense (vector) and sparse (BM25) search.

    Args:
        k: number of candidates each retriever (and the fused result) returns.
        dense_w: weight applied to the dense retriever's scores.
        sparse_w: weight applied to the BM25 retriever's scores.

    Returns:
        A QueryFusionRetriever using relative-score fusion over both retrievers.
    """
    dense_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=k,
    )
    # BM25 index was persisted to disk ahead of time.
    sparse_retriever = BM25Retriever.from_persist_dir("./sparse_retriever")
    sparse_retriever.similarity_top_k = k

    return QueryFusionRetriever(
        [
            dense_retriever,
            sparse_retriever,
        ],
        num_queries=1,  # no LLM query expansion, just the user's query
        use_async=False,
        retriever_weights=[dense_w, sparse_w],
        similarity_top_k=k,
        mode="relative_score",
        verbose=True,
    )

retriever = load_retriver(num_k, dense_weightage, sparse_weightage)

def extract_pdf_content(pdf_file_path):
    """Extract and concatenate the text of every page of a PDF.

    Args:
        pdf_file_path: filesystem path to the PDF.

    Returns:
        All page text joined into a single string. Pages where PyPDF2's
        extract_text() yields None (e.g. image-only pages) contribute ""
        instead of raising a TypeError.
    """
    with open(pdf_file_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # Single join instead of quadratic += concatenation; iterate pages
        # directly rather than indexing by range(len(...)).
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)

# Prompt template for LLM summarization of a retrieved legal document.
# {document_content} is filled via str.format() with the raw PDF text before
# the prompt is sent to Settings.llm.complete().
template = """
Please summarize the following legal document and provide the summary in the specified format. The output should directly follow the format without any introductory text.
**Document:**
{document_content}

**Format:**

**Case:** [Case Number]

**Petitioner:** [Petitioner's Name]

**Respondent:** [Respondent's Name]

**Judge:** [Judge's Name]

**Order Date:** [Order Date]

**Summary:**
- **Background:** [Brief description of the case background]
- **Allegations:** [Summary of the allegations made in the case]
- **Investigation:** [Key findings from the investigation]
- **Court's Decision:** [Summary of the court's decision and any conditions imposed]
"""

st.title("Legal Documents Hybrid+Reranker Search")

query = st.text_input("Search through documents by keyword", value="")

search_btn = st.button("Search")

# Main flow: hybrid retrieval -> FlashRank rerank -> per-hit on-demand summary.
if search_btn and query:
    nodes = retriever.retrieve(query)
    # FlashRank expects plain dicts, not llama-index node objects.
    passages = [
        {"id": node.node_id, "text": node.text, "meta": node.metadata}
        for node in nodes
    ]
    rerankrequest = RerankRequest(query=query, passages=passages)
    results = ranker.rerank(rerankrequest)

    for node in results:
        file_name = node["meta"].get("file_name")
        st.write("File Name: ", file_name)
        st.write("reranking score: ", node["score"])
        st.write("node id", node["id"])
        with st.expander("See Summary"):
            if file_name:
                text = extract_pdf_content(os.path.join("./documents", file_name))
                formatted_template = template.format(document_content=text)
                summary = Settings.llm.complete(formatted_template)
                st.markdown(summary)
            else:
                # Guard: without a file name the original path concat would
                # raise TypeError; show a notice instead of crashing the app.
                st.warning("No source file recorded for this result.")
        st.write("---")