import streamlit as st

import shelve
import hashlib
import os
import re
import time

import docx2txt
import PyPDF2
import requests
from dotenv import load_dotenv

import torch
import nltk
from nltk import sent_tokenize
from sentence_transformers import SentenceTransformer, util
from transformers import LEDTokenizer, LEDForConditionalGeneration

# Sentence-tokenizer data for NLTK ("punkt_tab" is required by newer NLTK releases).
nltk.download('punkt')
nltk.download('punkt_tab')

st.set_page_config(page_title="Legal Document Summarizer", layout="wide")

st.title("📄 Legal Document Summarizer (Upload)")

USER_AVATAR = "👤"
BOT_AVATAR = "🤖"

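# --- Chat history persistence ---
# Conversations are stored in a local shelve database so they survive app restarts.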
def load_chat_history():
    with shelve.open("chat_history") as db:
        return db.get("messages", [])


def save_chat_history(messages):
    with shelve.open("chat_history") as db:
        db["messages"] = messages


def limit_text(text, word_limit=500):
    words = text.split()
    return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")

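# --- Text cleaning ---
# Normalizes whitespace and strips common legal-document artifacts
# (page markers, signature rules, dot leaders).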
def clean_text(text):
    # Collapse line breaks and repeated whitespace.
    text = text.replace('\r\n', ' ').replace('\n', ' ')
    text = re.sub(r'\s+', ' ', text)

    # Remove "Page X of Y" markers.
    text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)

    # Remove long runs of underscores, dashes, and dot leaders.
    text = re.sub(r'_{5,}', '', text)
    text = re.sub(r'-{5,}', '', text)
    text = re.sub(r'\.{4,}', '', text)

    return text.strip()
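# --- Zero-shot section classification via the Hugging Face Inference API ---
# Requires an HF_API_TOKEN in the environment (loaded from a .env file).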
load_dotenv()
HF_API_TOKEN = os.getenv("HF_API_TOKEN")


def classify_zero_shot_hfapi(text, labels):
    if not HF_API_TOKEN:
        return "❌ Hugging Face token not found."

    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    payload = {
        "inputs": text,
        "parameters": {"candidate_labels": labels}
    }

    response = requests.post(
        "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-1",
        headers=headers,
        json=payload,
        timeout=60,  # avoid hanging indefinitely on a slow API response
    )

    if response.status_code != 200:
        return f"❌ Error from HF API: {response.status_code} - {response.text}"

    result = response.json()
    return result["labels"][0]  # highest-scoring candidate label

SECTION_LABELS = ["Facts", "Arguments", "Judgment", "Other"]


def classify_chunk(text):
    return classify_zero_shot_hfapi(text, SECTION_LABELS)

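# --- Sectioning ---
# Sentences are grouped into 3-sentence chunks; each chunk is classified
# zero-shot and appended to the matching section bucket.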
def section_by_zero_shot(text):
    sections = {"Facts": "", "Arguments": "", "Judgment": "", "Other": ""}
    sentences = sent_tokenize(text)
    chunk = ""

    for i, sent in enumerate(sentences):
        chunk += sent + " "
        if (i + 1) % 3 == 0 or i == len(sentences) - 1:
            label = classify_chunk(chunk.strip())
            print(f"🔍 Chunk: {chunk[:60]}...\n🏷️ Predicted Label: {label}")

            label = label.capitalize()
            if label not in sections:
                label = "Other"
            sections[label] += chunk + "\n"
            chunk = ""

    return sections

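# --- File text extraction ---
# Pulls raw text out of PDF, DOCX, or TXT uploads.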
def extract_text(file):
    if file.name.endswith(".pdf"):
        reader = PyPDF2.PdfReader(file)
        # extract_text() can return None for scanned/image-only pages.
        full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
    elif file.name.endswith(".docx"):
        full_text = docx2txt.process(file)
    elif file.name.endswith(".txt"):
        full_text = file.read().decode("utf-8")
    else:
        return "Unsupported file type."

    return full_text

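# --- Model loading ---
# Both models are cached with st.cache_resource so they load once per process:
# Legal-BERT embeddings for extractive scoring, LED for abstractive generation.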
@st.cache_resource
def load_legalbert():
    return SentenceTransformer("nlpaueb/legal-bert-base-uncased")


legalbert_model = load_legalbert()


@st.cache_resource
def load_led():
    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
    return tokenizer, model


tokenizer_led, model_led = load_led()

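# --- Extractive summarization ---
# Scores each sentence by cosine similarity to the mean document embedding
# and keeps the top fraction, preserving original sentence order.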
def legalbert_extractive_summary(text, top_ratio=0.2):
    sentences = sent_tokenize(text)
    top_k = max(3, int(len(sentences) * top_ratio))

    if len(sentences) <= top_k:
        return text

    sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
    doc_embedding = torch.mean(sentence_embeddings, dim=0)
    cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
    top_results = torch.topk(cosine_scores, k=top_k)

    # Re-sort the selected indices so the summary reads in document order.
    selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
    return " ".join(selected_sentences)

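# --- Abstractive summarization ---
# LED accepts inputs up to 16K tokens; global attention is enabled on the
# first token, as recommended for LED summarization.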
def led_abstractive_summary(text, max_length=512, min_length=100):
    inputs = tokenizer_led(
        text, return_tensors="pt", padding="max_length",
        truncation=True, max_length=4096
    )
    # LED expects a global attention mask; here only the first token attends globally.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    outputs = model_led.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=max_length,
        min_length=min_length,
        length_penalty=2.0,
        num_beams=4,
    )
    return tokenizer_led.decode(outputs[0], skip_special_tokens=True)

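# --- Hybrid pipeline ---
# Clean -> section -> extractive (Legal-BERT) -> abstractive (LED), per section.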
def hybrid_summary_by_section(text, top_ratio=0.8):
    cleaned_text = clean_text(text)
    sections = section_by_zero_shot(cleaned_text)

    summary_parts = []
    for name, content in sections.items():
        if content.strip():
            # Extractive pass keeps the top-scoring sentences of the section.
            extractive = legalbert_extractive_summary(content, top_ratio)

            # Abstractive pass rewrites the extractive output.
            abstractive = led_abstractive_summary(extractive)

            hybrid = f"📌 **Extractive Summary:**\n{extractive}\n\n📝 **Abstractive Summary:**\n{abstractive}"
            summary_parts.append(f"### 📘 {name} Section:\n{clean_text(hybrid)}")

    return "\n\n".join(summary_parts)

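# --- Streamlit session state and sidebar controls ---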
if "messages" not in st.session_state: |
|
st.session_state.messages = load_chat_history() |
|
|
|
|
|
if "last_uploaded" not in st.session_state: |
|
st.session_state.last_uploaded = None |
|
|
|
|
|
with st.sidebar: |
|
st.subheader("βοΈ Options") |
|
if st.button("Delete Chat History"): |
|
st.session_state.messages = [] |
|
st.session_state.last_uploaded = None |
|
save_chat_history([]) |
|
|
|
|
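# Render text character by character to simulate typing.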
def display_with_typing_effect(text, speed=0.005):
    placeholder = st.empty()
    displayed_text = ""
    for char in text:
        displayed_text += char
        placeholder.markdown(displayed_text)
        time.sleep(speed)
    return displayed_text


# Replay the stored conversation with the appropriate avatars.
for message in st.session_state.messages:
    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])

prompt = st.chat_input("Type a message...")


with st.container():
    st.subheader("📄 Upload a Legal Document")
    uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
    reprocess_btn = st.button("🔁 Reprocess Last Uploaded File")

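# Hash the upload so the same file isn't re-summarized on every rerun.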
def get_file_hash(file):
    file.seek(0)
    content = file.read()
    file.seek(0)  # rewind so the file can be read again downstream
    return hashlib.md5(content).hexdigest()
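# --- Upload handler ---
# Summarize a new upload (or re-run on demand), append both turns to the chat,
# persist, and rerun so the history renders through the normal display loop.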
if uploaded_file:
    file_hash = get_file_hash(uploaded_file)

    if file_hash != st.session_state.get("last_uploaded_hash") or reprocess_btn:
        raw_text = extract_text(uploaded_file)
        summary_text = hybrid_summary_by_section(raw_text)

        st.session_state.messages.append({
            "role": "user",
            "content": f"📤 Uploaded **{uploaded_file.name}**"
        })

        with st.chat_message("assistant", avatar=BOT_AVATAR):
            preview_text = f"🧾 **Hybrid Summary of {uploaded_file.name}:**\n\n{summary_text}"
            display_with_typing_effect(clean_text(preview_text), speed=0.000001)

        st.session_state.messages.append({
            "role": "assistant",
            "content": preview_text
        })

        # Record the hash only for a fresh upload (not a manual reprocess).
        if not reprocess_btn:
            st.session_state.last_uploaded_hash = file_hash

        save_chat_history(st.session_state.messages)
        st.rerun()

if prompt:
    raw_text = prompt
    summary_text = hybrid_summary_by_section(raw_text)

    st.session_state.messages.append({
        "role": "user",
        "content": prompt
    })

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        bot_response = f"📝 **Hybrid Summary of your text:**\n\n{summary_text}"
        display_with_typing_effect(clean_text(bot_response), speed=0.000005)

    st.session_state.messages.append({
        "role": "assistant",
        "content": bot_response
    })

    save_chat_history(st.session_state.messages)
    st.rerun()