import re
import time
from dataclasses import dataclass
from typing import List

import pandas as pd
import streamlit as st
import torch
from docx import Document
from dotenv import load_dotenv
from pdfminer.high_level import extract_text
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

load_dotenv()

# phi-2 is a small causal LM used here both for answer generation and for embeddings.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)

EMBEDDING_SEG_LEN = 1500  # characters per segment when splitting extracted PDF text

# The settings below appear to be left over from an OpenAI-based setup; nothing in
# the local phi-2 pipeline reads them.
EMBEDDING_MODEL = "gpt-4"
EMBEDDING_CTX_LENGTH = 8191
EMBEDDING_ENCODING = "cl100k_base"
ENCODING = "gpt2"


@dataclass
class Paragraph:
    page_num: int
    paragraph_num: int
    content: str


def read_pdf_pdfminer(file_path) -> List[Paragraph]:
    # pdfminer's extract_text drops page boundaries, so page_num is recorded as 0.
    text = extract_text(file_path).replace('\n', ' ').strip()
    paragraphs = batched(text, EMBEDDING_SEG_LEN)
    paragraphs_objs = []
    for paragraph_num, p in enumerate(paragraphs, start=1):
        paragraphs_objs.append(Paragraph(0, paragraph_num, p))
    return paragraphs_objs


def read_docx(file) -> List[Paragraph]:
    doc = Document(file)
    paragraphs = []
    for paragraph_num, paragraph in enumerate(doc.paragraphs, start=1):
        content = paragraph.text.strip()
        if content:
            paragraphs.append(Paragraph(1, paragraph_num, content))
    return paragraphs


def count_tokens(text, tokenizer):
    return len(tokenizer.encode(text))


def batched(iterable, n):
    """Yield successive slices of length n from a sliceable iterable."""
    length = len(iterable)
    for ndx in range(0, length, n):
        yield iterable[ndx : min(ndx + n, length)]
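
# For reference: list(batched("abcdef", 4)) == ["abcd", "ef"]. read_pdf_pdfminer
# relies on this to cut extracted text into EMBEDDING_SEG_LEN-character segments.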


def compute_doc_embeddings(df, tokenizer):
    # Returns a dict mapping each DataFrame index to the embedding of its content.
    embeddings = {}
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        embeddings[index] = get_embedding(row["content"], tokenizer)
    return embeddings


def enhanced_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
    paragraphs = [para for para in document.split("\n") if para]
    scores = [
        sum(
            para.lower().count(keyword) * tfidf_scores[vectorizer.vocabulary_[keyword]]
            for keyword in keywords
            # Also guard against keywords missing from the fitted vocabulary (KeyError).
            if keyword in para.lower() and keyword in vectorizer.vocabulary_
        )
        for para in paragraphs
    ]

    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    relevant_paragraphs = [paragraphs[i] for i in top_indices]

    return " ".join(relevant_paragraphs)


def targeted_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
    # Same ranking as enhanced_context_extraction, but without requiring the keyword
    # to appear in the paragraph before weighting its count.
    paragraphs = [para for para in document.split("\n") if para]
    scores = [
        sum(
            para.lower().count(keyword) * tfidf_scores[vectorizer.vocabulary_[keyword]]
            for keyword in keywords
            if keyword in vectorizer.vocabulary_  # avoid KeyError on unseen keywords
        )
        for para in paragraphs
    ]

    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    relevant_paragraphs = [paragraphs[i] for i in top_indices]

    return " ".join(relevant_paragraphs)
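
# Both extractors expect a fitted TfidfVectorizer plus one score per vocabulary entry.
# A minimal sketch of how they might be wired up (not invoked anywhere in this app):
#   vectorizer = TfidfVectorizer()
#   tfidf_scores = vectorizer.fit_transform([document]).toarray()[0]
#   enhanced_context_extraction(document, ["termination"], vectorizer, tfidf_scores)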


def extract_page_and_clause_references(paragraph: str) -> str:
    page_matches = re.findall(r'Page (\d+)', paragraph)
    clause_matches = re.findall(r'Clause (\d+\.\d+)', paragraph)

    page_ref = f"Page {page_matches[0]}" if page_matches else ""
    clause_ref = f"Clause {clause_matches[0]}" if clause_matches else ""

    # Join whichever references were found; return "" rather than "(, )" when neither was.
    refs = ", ".join(ref for ref in (page_ref, clause_ref) if ref)
    return f"({refs})" if refs else ""


def refine_answer_based_on_question(question: str, answer: str) -> str:
    # Crude containment heuristic: phrase yes/no answers for "Does the agreement
    # contain ..." questions; leave all other answers untouched.
    if "Does the agreement contain" in question:
        if "not" in answer or "No" in answer:
            refined_answer = f"No, the agreement does not contain {answer}"
        else:
            refined_answer = f"Yes, the agreement contains {answer}"
    else:
        refined_answer = answer

    return refined_answer


def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, model, top_n_paragraphs: int = 5) -> str:
    question_words = set(question.split())

    priority_keywords = ["duration", "term", "period", "month", "year", "day", "week", "agreement", "obligation", "effective date"]

    # Score each paragraph by word overlap with the question plus priority-keyword hits.
    df['relevance_score'] = df['content'].apply(
        lambda x: len(question_words.intersection(set(x.split())))
        + sum(x.lower().count(pk) for pk in priority_keywords)
    )

    most_relevant_paragraphs = df.sort_values(by='relevance_score', ascending=False).iloc[:top_n_paragraphs]['content'].tolist()

    context = "\n\n".join(most_relevant_paragraphs)
    prompt = f"Question: {question}\n\nContext: {context}\n\nAnswer:"

    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    outputs = model.generate(**inputs, max_length=600)
    # Decode only the newly generated tokens so the prompt is not echoed into the answer.
    answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

    references = extract_page_and_clause_references(context)
    answer = refine_answer_based_on_question(question, answer) + " " + references

    return answer
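
# Illustrative call, using a made-up single-row DataFrame in the shape main() builds:
#   toy_df = pd.DataFrame([(1, 1, "The term of this agreement is 24 months.", 11)],
#                         columns=["page_num", "paragraph_num", "content", "tokens"])
#   answer_query_with_context("What is the duration of the agreement?", toy_df, tokenizer, model)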


def get_embedding(text, tokenizer):
    try:
        inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            # Causal-LM outputs expose hidden states only when explicitly requested;
            # they have no last_hidden_state attribute.
            outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the final layer across tokens to get one fixed-size vector.
        embedding = outputs.hidden_states[-1].mean(dim=1).squeeze(0)
    except Exception as e:
        print("Error obtaining embedding:", e)
        embedding = []
    return embedding
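
# Sketch of downstream use (cosine similarity is one reasonable comparison; this app
# does not currently consume the embeddings anywhere):
#   sim = torch.nn.functional.cosine_similarity(
#       get_embedding(text_a, tokenizer), get_embedding(text_b, tokenizer), dim=0)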


def save_as_pdf(conversation):
    pdf_filename = "conversation.pdf"
    c = canvas.Canvas(pdf_filename, pagesize=letter)

    c.drawString(100, 750, "Conversation:")
    y_position = 730
    for q, a in conversation:
        c.drawString(120, y_position, f"Q: {q}")
        c.drawString(120, y_position - 20, f"A: {a}")
        y_position -= 40
        if y_position < 40:  # start a new page before drawing off the bottom edge
            c.showPage()
            y_position = 750

    c.save()

    # Serve the file through Streamlit; a bare relative markdown link is not downloadable.
    with open(pdf_filename, "rb") as f:
        st.download_button("Download PDF", f, file_name=pdf_filename)


def save_as_docx(conversation):
    doc = Document()
    doc.add_heading('Conversation', 0)

    for q, a in conversation:
        doc.add_paragraph(f'Q: {q}')
        doc.add_paragraph(f'A: {a}')

    doc_filename = "conversation.docx"
    doc.save(doc_filename)

    with open(doc_filename, "rb") as f:
        st.download_button("Download DOCX", f, file_name=doc_filename)


def save_as_xlsx(conversation):
    df = pd.DataFrame(conversation, columns=["Question", "Answer"])
    xlsx_filename = "conversation.xlsx"
    df.to_excel(xlsx_filename, index=False)

    with open(xlsx_filename, "rb") as f:
        st.download_button("Download XLSX", f, file_name=xlsx_filename)


def save_as_txt(conversation):
    txt_filename = "conversation.txt"
    with open(txt_filename, "w") as txt_file:
        for q, a in conversation:
            txt_file.write(f"Q: {q}\nA: {a}\n\n")

    with open(txt_filename, "rb") as f:
        st.download_button("Download TXT", f, file_name=txt_filename)


def main():
    st.markdown("<h1>Ask anything from Legal Texts</h1>", unsafe_allow_html=True)
    st.markdown("<h2>Upload documents</h2>", unsafe_allow_html=True)

    uploaded_files = st.file_uploader("Upload one or more documents", type=['pdf', 'docx'], accept_multiple_files=True)
    question = st.text_input("Ask a question based on the documents", key="question_input")

    # Purely cosmetic progress bar shown on each rerun while the page initializes.
    progress = st.progress(0)
    for i in range(100):
        progress.progress(i + 1)
        time.sleep(0.01)

    if uploaded_files:
        df = pd.DataFrame(columns=["page_num", "paragraph_num", "content", "tokens"])
        for uploaded_file in uploaded_files:
            paragraphs = read_pdf_pdfminer(uploaded_file) if uploaded_file.type == "application/pdf" else read_docx(uploaded_file)
            temp_df = pd.DataFrame(
                [(p.page_num, p.paragraph_num, p.content, count_tokens(p.content, tokenizer))
                 for p in paragraphs],
                columns=["page_num", "paragraph_num", "content", "tokens"]
            )
            df = pd.concat([df, temp_df], ignore_index=True)

        if "interactions" not in st.session_state:
            st.session_state["interactions"] = []

        answer = ""
        # Only query the model for a non-empty question that hasn't just been answered.
        if question and question != st.session_state.get("last_question", ""):
            st.text("Searching...")
            answer = answer_query_with_context(question, df, tokenizer, model)
            st.session_state["interactions"].append((question, answer))
            st.write(answer)

        st.markdown("### Interaction History")
        for q, a in st.session_state["interactions"]:
            st.write(f"**Q:** {q}\n\n**A:** {a}")

        st.session_state["last_question"] = question

        st.markdown("<h2>Sample paragraphs</h2>", unsafe_allow_html=True)
        sample_size = min(len(df), 5)
        st.dataframe(df.sample(n=sample_size))

        if st.button("Save as PDF"):
            save_as_pdf(st.session_state["interactions"])
        if st.button("Save as DOCX"):
            save_as_docx(st.session_state["interactions"])
        if st.button("Save as XLSX"):
            save_as_xlsx(st.session_state["interactions"])
        if st.button("Save as TXT"):
            save_as_txt(st.session_state["interactions"])

    else:
        st.markdown("<h2>Please upload a document to proceed.</h2>", unsafe_allow_html=True)


if __name__ == "__main__":
    main()