# tdocaibot / search.py
# albhu's picture
# Update search.py
# a284200 verified
import time
import streamlit as st
import pandas as pd
import os
from dotenv import load_dotenv
import search # Import the search module
from transformers import AutoTokenizer, AutoModelForCausalLM
from docx import Document
from pdfminer.high_level import extract_text
from dataclasses import dataclass
from typing import List
from tqdm import tqdm
import re
from sklearn.feature_extraction.text import TfidfVectorizer
load_dotenv()

# Hugging Face checkpoint used for tokenisation, generation and embeddings.
PHI2_CHECKPOINT = "microsoft/phi-2"


@st.cache_resource
def _load_phi2(checkpoint=PHI2_CHECKPOINT):
    """Load the tokenizer and causal LM for *checkpoint* once per process.

    Streamlit re-executes this script on every widget interaction. Without
    ``st.cache_resource`` the model weights would be deserialised again on
    each rerun.

    NOTE(review): ``trust_remote_code=True`` lets the checkpoint run its own
    Python code; keep it only if the checkpoint source is trusted.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
    return tok, lm


# Module-level names kept for the rest of this file.
tokenizer, model = _load_phi2()

# Character length of each chunk produced by read_pdf_pdfminer (via batched).
EMBEDDING_SEG_LEN = 1500
# NOTE(review): the constants below look like leftovers from an OpenAI-based
# version of this app and appear unused in this module.
EMBEDDING_MODEL = "gpt-4"
EMBEDDING_CTX_LENGTH = 8191
EMBEDDING_ENCODING = "cl100k_base"
ENCODING = "gpt2"
@dataclass
class Paragraph:
    """One chunk of text extracted from an uploaded document.

    Instances are built positionally as
    ``Paragraph(page_num, paragraph_num, content)``.
    """

    # Page number the reader assigned to this chunk.
    page_num: int
    # 1-based position of the chunk within its document, as assigned by the reader.
    paragraph_num: int
    # The chunk's text.
    content: str
def read_pdf_pdfminer(file_path) -> List[Paragraph]:
    """Extract a PDF's text and cut it into fixed-size chunks, page by page.

    ``pdfminer.high_level.extract_text`` ends every page with a form-feed
    character. Splitting on it lets each chunk carry its real 1-based page
    number instead of a constant.

    Within a page, newlines are collapsed to spaces and the text is sliced
    into ``EMBEDDING_SEG_LEN``-character chunks. Chunks never span two pages.
    Pages without extractable text produce no chunks.

    Args:
        file_path: path or binary file-like object accepted by
            ``extract_text`` (``main`` passes a Streamlit upload).

    Returns:
        Paragraph objects numbered consecutively from 1 across the document.
    """
    paragraphs_objs = []
    paragraph_num = 1
    pages = extract_text(file_path).split("\f")
    for page_num, page_text in enumerate(pages, start=1):
        page_text = page_text.replace("\n", " ").strip()
        for chunk in batched(page_text, EMBEDDING_SEG_LEN):
            paragraphs_objs.append(Paragraph(page_num, paragraph_num, chunk))
            paragraph_num += 1
    return paragraphs_objs
def read_docx(file) -> List[Paragraph]:
    """Return the non-blank paragraphs of a .docx document as Paragraph objects.

    No page information is extracted: every paragraph is tagged as page 1.

    ``paragraph_num`` is the 1-based position in ``doc.paragraphs``, counting
    blank paragraphs too. Numbering therefore has gaps where blank paragraphs
    were dropped. Text is stripped of surrounding whitespace.
    """
    result = []
    for position, docx_paragraph in enumerate(Document(file).paragraphs, start=1):
        text = docx_paragraph.text.strip()
        if not text:
            continue
        result.append(Paragraph(1, position, text))
    return result
def count_tokens(text, tokenizer):
    """Return the number of items ``tokenizer.encode(text)`` produces."""
    token_ids = tokenizer.encode(text)
    return len(token_ids)
def batched(iterable, n):
    """Lazily yield consecutive slices of *iterable* holding at most *n* items.

    *iterable* must support ``len`` and slicing. Each chunk has the same
    type as the input (``str`` chunks for a ``str``). The last chunk may be
    shorter.

    Note this differs from ``itertools.batched`` (Python 3.12+), which
    yields tuples.
    """
    yield from (iterable[start:start + n] for start in range(0, len(iterable), n))
def compute_doc_embeddings(df, tokenizer):
    """Map each row index of *df* to ``get_embedding`` of that row's ``content``.

    Rows are processed in order behind a tqdm progress bar.
    """
    row_iter = tqdm(df.iterrows(), total=df.shape[0])
    return {
        row_index: get_embedding(row["content"], tokenizer)
        for row_index, row in row_iter
    }
def enhanced_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
    """Return the *top_n* lines of *document* that score highest for *keywords*.

    The document is split on newlines and empty lines are dropped. Each line
    scores ``sum(count(term) * tfidf_scores[vectorizer.vocabulary_[term]])``
    over the keywords. Counts are case-insensitive substring counts.

    Keywords are lower-cased before counting and vocabulary lookup, which
    assumes a lower-casing vectorizer (the ``TfidfVectorizer`` default).
    Keywords missing from ``vocabulary_`` cannot be weighted and are skipped
    instead of raising ``KeyError``.

    Lines are ranked by descending score (ties keep document order) and
    joined with single spaces.

    Args:
        document: raw text whose non-empty lines are the candidates.
        keywords: terms to score by; duplicates count once per occurrence.
        vectorizer: fitted object exposing ``vocabulary_`` (term -> column).
        tfidf_scores: indexable by vocabulary column; presumably a 1-D array
            of per-term TF-IDF weights (confirm against the caller).
        top_n: maximum number of lines to keep.
    """
    paragraphs = [para for para in document.split("\n") if para]
    vocabulary = vectorizer.vocabulary_
    # Resolve each keyword's weight once; a list keeps duplicate keywords.
    weighted_terms = [
        (term, tfidf_scores[vocabulary[term]])
        for term in (keyword.lower() for keyword in keywords)
        if term in vocabulary
    ]
    lowered = [para.lower() for para in paragraphs]
    scores = [sum(text.count(term) * weight for term, weight in weighted_terms) for text in lowered]
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    return " ".join(paragraphs[i] for i in top_indices)
def targeted_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
    """Return the *top_n* lines of *document* that score highest for *keywords*.

    A line's score is the sum, over keywords, of the keyword's
    case-insensitive substring count in the line times its weight
    ``tfidf_scores[vectorizer.vocabulary_[keyword]]``.

    Keywords are lower-cased first, which assumes a lower-casing vectorizer
    (the ``TfidfVectorizer`` default). Keywords absent from ``vocabulary_``
    are skipped rather than raising ``KeyError``.

    The best lines (ties in document order) are joined with single spaces.

    Args:
        document: raw text; its non-empty lines are the candidates.
        keywords: terms to score by.
        vectorizer: fitted object exposing ``vocabulary_`` (term -> column).
        tfidf_scores: indexable by vocabulary column; presumably a 1-D array
            of TF-IDF weights (confirm against the caller).
        top_n: maximum number of lines to keep.
    """
    paragraphs = [para for para in document.split("\n") if para]

    # Look up each keyword's weight once; unknown terms cannot be weighted.
    weighted_terms = []
    for keyword in keywords:
        term = keyword.lower()
        column = vectorizer.vocabulary_.get(term)
        if column is not None:
            weighted_terms.append((term, tfidf_scores[column]))

    scores = []
    for para in paragraphs:
        text = para.lower()
        scores.append(sum(text.count(term) * weight for term, weight in weighted_terms))

    ranked = sorted(range(len(paragraphs)), key=scores.__getitem__, reverse=True)
    return " ".join(paragraphs[i] for i in ranked[:top_n])
def extract_page_and_clause_references(paragraph: str) -> str:
    """Build a "(Page N, Clause X.Y)" citation from the first matches in *paragraph*.

    Only the first "Page <digits>" match is used, and likewise the first
    "Clause <d>.<d>[.<d>...]" match (multi-level numbers such as 1.2.3 are
    kept whole). Both patterns are case-sensitive.

    Parts that are not found are omitted. An empty string is returned when
    neither is present.

    >>> extract_page_and_clause_references("See Page 4, Clause 2.1")
    '(Page 4, Clause 2.1)'
    """
    page_match = re.search(r"Page (\d+)", paragraph)
    clause_match = re.search(r"Clause (\d+(?:\.\d+)+)", paragraph)
    parts = []
    if page_match:
        parts.append(f"Page {page_match.group(1)}")
    if clause_match:
        parts.append(f"Clause {clause_match.group(1)}")
    # Join only the parts that were found so no dangling ", " or "(, )" appears.
    return f"({', '.join(parts)})" if parts else ""
def refine_answer_based_on_question(question: str, answer: str) -> str:
    """Prefix a yes/no verdict to *answer* for "Does the agreement contain ..." questions.

    The question check is case-insensitive. Other questions get *answer*
    back unchanged.

    The answer counts as negative when it contains "no", "not" or "cannot"
    as whole words, or an "n't" contraction (all case-insensitive). Matching
    whole words stops words such as "notice", "note", "Notwithstanding" or
    "November" from being read as a negation.
    """
    if "does the agreement contain" not in question.lower():
        return answer
    negative = re.search(r"\b(?:no|not|cannot)\b|n['\u2019]t\b", answer, re.IGNORECASE)
    if negative:
        return f"No, the agreement does not contain {answer}"
    return f"Yes, the agreement contains {answer}"
def answer_query_with_context(
    question: str,
    df: pd.DataFrame,
    tokenizer,
    model,
    top_n_paragraphs: int = 5,
    max_prompt_tokens: int = 512,
    max_length: int = 600,
) -> str:
    """Answer *question* with the causal LM, using the most relevant rows of *df* as context.

    Each row's ``content`` is scored by adding two numbers: how many distinct
    whitespace-separated question words it shares, and the case-insensitive
    substring counts of a fixed list of contract keywords. The
    *top_n_paragraphs* best rows, joined by blank lines, form the context.

    Only the context is trimmed to fit the prompt into *max_prompt_tokens*,
    so the trailing "Answer:" cue survives. Plain right-side truncation of
    the whole prompt would cut that cue off.

    Only newly generated tokens are decoded, so the reply does not echo the
    prompt.

    Args:
        question: the user's question.
        df: needs a ``content`` column of strings; it is not modified.
        tokenizer: Hugging Face tokenizer paired with *model*.
        model: causal LM whose ``generate`` receives PyTorch tensors.
        top_n_paragraphs: number of rows used as context.
        max_prompt_tokens: token budget for the prompt.
        max_length: ``max_length`` passed to ``model.generate``
            (prompt plus new tokens).

    Returns:
        The yes/no-refined answer followed by any page/clause citation found
        in the context the model saw.
    """
    question_words = set(question.split())
    priority_keywords = ["duration", "term", "period", "month", "year", "day", "week", "agreement", "obligation", "effective date"]

    def relevance(text):
        overlap = len(question_words.intersection(set(text.split())))
        return overlap + sum(text.lower().count(pk) for pk in priority_keywords)

    # Score on a copy so the caller's DataFrame does not gain a column.
    ranked = df.assign(relevance_score=df["content"].apply(relevance))
    most_relevant_paragraphs = (
        ranked.sort_values(by="relevance_score", ascending=False)
        .iloc[:top_n_paragraphs]["content"]
        .tolist()
    )
    context = "\n\n".join(most_relevant_paragraphs)

    prefix = f"Question: {question}\n\nContext: "
    suffix = "\n\nAnswer:"
    overhead = len(tokenizer.encode(prefix)) + len(tokenizer.encode(suffix))
    budget = max(max_prompt_tokens - overhead, 0)
    context_ids = tokenizer.encode(context)
    if len(context_ids) > budget:
        context = tokenizer.decode(context_ids[:budget], skip_special_tokens=True)

    prompt = f"{prefix}{context}{suffix}"
    # Token counts of the separate pieces are approximate; keep truncation as a safety net.
    inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=max_prompt_tokens, truncation=True)
    outputs = model.generate(inputs, max_length=max_length)
    # For decoder-only models generate() returns prompt + continuation; keep only the continuation.
    new_tokens = outputs[0][inputs.shape[-1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    references = extract_page_and_clause_references(context)
    refined = refine_answer_based_on_question(question, answer)
    return " ".join(part for part in (refined, references) if part)
def get_embedding(text, tokenizer):
    """Return the module-level ``model``'s final-layer hidden states for *text*.

    ``model`` is loaded with ``AutoModelForCausalLM``. Its output has no
    ``last_hidden_state`` attribute, so hidden states are requested with
    ``output_hidden_states=True`` and the last layer,
    ``hidden_states[-1]``, is returned.

    Input is truncated to 512 tokens. The forward pass runs under
    ``torch.no_grad()`` because no gradients are needed.

    Returns:
        A tensor, presumably shaped (1, seq_len, hidden_size) for one string.
        Returns ``[]`` if tokenisation or the forward pass fails; the error
        is printed.
    """
    import torch

    try:
        inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        return outputs.hidden_states[-1]
    except Exception as e:
        # Best-effort: callers receive [] instead of an exception.
        print("Error obtaining embedding:", e)
        return []
def save_as_pdf(conversation):
    """Offer the Q/A *conversation* (pairs of strings) as a downloadable PDF.

    Each "Q: ..." and "A: ..." entry is word-wrapped and drawn line by line.
    A new page starts when the cursor reaches the bottom margin, so long
    answers and long histories stay on the page.

    The PDF is built in memory and served with ``st.download_button``.
    Streamlit does not serve files from the app's working directory, so a
    relative markdown link to one does not download it.

    NOTE(review): ``canvas`` and ``letter`` are not imported anywhere in this
    module, so this function raises NameError as written. They match
    reportlab's ``reportlab.pdfgen.canvas`` and
    ``reportlab.lib.pagesizes.letter``. Add those imports (and the reportlab
    dependency) to enable PDF export.
    """
    import io
    import textwrap

    # Coordinates in points from the bottom-left corner; wrap width is an
    # approximate fit for text starting at x=120 on a letter-sized page.
    top, bottom, line_height, wrap_width = 750, 50, 20, 75
    buffer = io.BytesIO()
    c = canvas.Canvas(buffer, pagesize=letter)
    c.drawString(100, top, "Conversation:")
    y_position = top - line_height
    for q, a in conversation:
        for entry in (f"Q: {q}", f"A: {a}"):
            # drawString draws a single line; textwrap folds embedded
            # newlines into spaces and splits the entry into short lines.
            for line in textwrap.wrap(entry, width=wrap_width):
                if y_position < bottom:
                    c.showPage()
                    y_position = top
                c.drawString(120, y_position, line)
                y_position -= line_height
    c.save()
    st.download_button(
        "Download PDF",
        data=buffer.getvalue(),
        file_name="conversation.pdf",
        mime="application/pdf",
    )
def save_as_docx(conversation):
    """Offer the Q/A *conversation* (pairs of strings) as a downloadable Word file.

    The document has a "Conversation" title followed by one "Q: ..." and one
    "A: ..." paragraph per pair.

    It is built in memory and served with ``st.download_button``. Streamlit
    does not serve files from the app's working directory, so a relative
    markdown link to one does not download it. Building in memory also
    avoids sessions overwriting a shared file on disk.
    """
    import io

    doc = Document()
    doc.add_heading('Conversation', 0)
    for q, a in conversation:
        doc.add_paragraph(f'Q: {q}')
        doc.add_paragraph(f'A: {a}')
    buffer = io.BytesIO()
    doc.save(buffer)
    st.download_button(
        "Download DOCX",
        data=buffer.getvalue(),
        file_name="conversation.docx",
        mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )
def save_as_xlsx(conversation):
    """Offer the Q/A *conversation* (pairs of strings) as a downloadable Excel sheet.

    The sheet has "Question" and "Answer" columns and no index column.

    It is written to memory and served with ``st.download_button``.
    Streamlit does not serve files from the app's working directory, so a
    relative markdown link to one does not download it.
    """
    import io

    df = pd.DataFrame(conversation, columns=["Question", "Answer"])
    buffer = io.BytesIO()
    # For a buffer (no file extension) pandas defaults to an .xlsx writer engine.
    df.to_excel(buffer, index=False)
    st.download_button(
        "Download XLSX",
        data=buffer.getvalue(),
        file_name="conversation.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )
def save_as_txt(conversation):
    """Offer the Q/A *conversation* (pairs of strings) as a downloadable UTF-8 text file.

    Each pair is rendered as "Q: ...", "A: ...", then a blank line.

    The text is encoded explicitly as UTF-8 rather than with the platform
    default encoding, which can fail on non-Latin text. It is served with
    ``st.download_button``, because Streamlit does not serve files from the
    app's working directory.
    """
    text = "".join(f"Q: {q}\nA: {a}\n\n" for q, a in conversation)
    st.download_button(
        "Download TXT",
        data=text.encode("utf-8"),
        file_name="conversation.txt",
        mime="text/plain",
    )
def _load_uploaded_documents(uploaded_files):
    """Parse every upload into one DataFrame of paragraphs, showing real progress.

    Uploads with MIME type ``application/pdf`` go through
    ``read_pdf_pdfminer``; everything else the uploader accepts is read as
    .docx.

    The frame is built once from a list of rows instead of calling
    ``pd.concat`` repeatedly inside the loop.
    """
    progress = st.progress(0)
    rows = []
    for done, uploaded_file in enumerate(uploaded_files, start=1):
        reader = read_pdf_pdfminer if uploaded_file.type == "application/pdf" else read_docx
        rows.extend(
            (p.page_num, p.paragraph_num, p.content, count_tokens(p.content, tokenizer))
            for p in reader(uploaded_file)
        )
        progress.progress(int(done * 100 / len(uploaded_files)))
    return pd.DataFrame(rows, columns=["page_num", "paragraph_num", "content", "tokens"])


def main():
    """Streamlit entry point: upload PDF/DOCX files, ask questions, export the Q/A history.

    Q/A pairs accumulate in ``st.session_state["interactions"]``. The model
    is queried only when the (non-empty) question differs from
    ``st.session_state["last_question"]``, because Streamlit reruns this
    function on every widget interaction.
    """
    st.markdown('<h1>Ask anything from Legal Texts</h1><p style="font-size: 12; color: gray;"></p>', unsafe_allow_html=True)
    st.markdown("<h2>Upload documents</h2>", unsafe_allow_html=True)
    uploaded_files = st.file_uploader("Upload one or more documents", type=['pdf', 'docx'], accept_multiple_files=True)
    question = st.text_input("Ask a question based on the documents", key="question_input")

    if not uploaded_files:
        st.markdown("<h2>Please upload a document to proceed.</h2>", unsafe_allow_html=True)
        return

    df = _load_uploaded_documents(uploaded_files)

    if "interactions" not in st.session_state:
        st.session_state["interactions"] = []

    # Skip empty questions: clearing the text box must not send "" to the model.
    if question and question != st.session_state.get("last_question", ""):
        st.text("Searching...")
        answer = answer_query_with_context(question, df, tokenizer, model)
        st.session_state["interactions"].append((question, answer))
        st.write(answer)
        st.markdown("### Interaction History")
        for q, a in st.session_state["interactions"]:
            st.write(f"**Q:** {q}\n\n**A:** {a}")
    st.session_state["last_question"] = question

    st.markdown("<h2>Sample paragraphs</h2>", unsafe_allow_html=True)
    sample_size = min(len(df), 5)
    st.dataframe(df.sample(n=sample_size))

    if st.button("Save as PDF"):
        save_as_pdf(st.session_state["interactions"])
    if st.button("Save as DOCX"):
        save_as_docx(st.session_state["interactions"])
    if st.button("Save as XLSX"):
        save_as_xlsx(st.session_state["interactions"])
    if st.button("Save as TXT"):
        save_as_txt(st.session_state["interactions"])


if __name__ == "__main__":
    main()