# weknow / app.py
import streamlit as st
import faiss
import numpy as np
from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer
from sentence_transformers import SentenceTransformer
from docx import Document
import PyPDF2 # Use PyPDF2 instead of PyMuPDF
import requests
from bs4 import BeautifulSoup
from langdetect import detect, LangDetectException
# Initialize models, cached so Streamlit does not reload them on every rerun
@st.cache_resource
def load_models():
    # distilbert-base-uncased has no trained QA head; use a SQuAD-fine-tuned checkpoint
    qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
    embedder = SentenceTransformer("distiluse-base-multilingual-cased-v1")
    # Translation model for on-the-fly translation
    m2m_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
    m2m_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
    return qa, embedder, m2m_tokenizer, m2m_model

qa_pipeline, embedding_model, tokenizer, model = load_models()
# FAISS index setup (in-memory), kept in session state so the knowledge base survives reruns
dimension = 512  # distiluse-base-multilingual-cased-v1 outputs 512-dimensional embeddings
if "index" not in st.session_state:
    st.session_state.index = faiss.IndexFlatL2(dimension)
    st.session_state.documents = []
index = st.session_state.index
documents = st.session_state.documents
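# Note: IndexFlatL2 performs exact (brute-force) L2 search and is held only in memory;
# if on-disk persistence were ever needed, faiss.write_index(index, "kb.index") and
# faiss.read_index("kb.index") could save and restore it ("kb.index" is an
# illustrative filename, not something this app uses).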
def translate_text(text, src_lang, tgt_lang):
    """Translate text between languages using the M2M100 model."""
    tokenizer.src_lang = src_lang
    encoded = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id(tgt_lang))
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
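# Illustrative usage (hypothetical input/output; M2M100 expects ISO 639-1 language codes):
#   translate_text("¿Dónde está la biblioteca?", "es", "en")  # -> "Where is the library?"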
# Sidebar for navigation
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ["Upload Knowledge", "Q&A"])
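# Streamlit re-runs this entire script on every interaction (including switching
# pages here), which is why the index and document store above live in st.session_state.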
# Page 1: Knowledge Upload
if page == "Upload Knowledge":
    st.title("Upload Knowledge Base")
    uploaded_files = st.file_uploader("Upload your files (DOCX, PDF)", type=["pdf", "docx"], accept_multiple_files=True)
    url = st.text_input("Or enter a website URL to scrape")
    if uploaded_files or url:
        st.write("Processing your data...")
        texts = []
        # Process uploaded files (file_uploader may return None when nothing is uploaded)
        for file in uploaded_files or []:
            try:
                if file.type == "application/pdf":
                    pdf_reader = PyPDF2.PdfReader(file)  # Use PyPDF2 for PDF reading
                    text = ""
                    for pdf_page in pdf_reader.pages:  # pdf_page, not page: avoid shadowing the sidebar selection
                        text += pdf_page.extract_text() or ""  # extract_text() can return None
                elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                    doc = Document(file)
                    text = " ".join(para.text for para in doc.paragraphs)
                else:
                    st.error(f"Unsupported file type: {file.type}")
                    continue
                # Language detection
                try:
                    detected_lang = detect(text)
                    st.write(f"Detected language: {detected_lang}")
                except LangDetectException:
                    st.error("Could not detect the language of the text.")
                    continue
                # Generate the embedding and add it to the FAISS index
                embedding = embedding_model.encode([text])[0]
                index.add(np.array([embedding], dtype=np.float32))
                documents.append(text)
                texts.append(text)
            except Exception as e:
                st.error(f"Error processing file: {e}")
        # Process URL
        if url:
            try:
                response = requests.get(url, timeout=30)
                soup = BeautifulSoup(response.text, "html.parser")
                text = soup.get_text(separator=" ", strip=True)
                try:
                    detected_lang = detect(text)
                    st.write(f"Detected language: {detected_lang}")
                    # Generate the embedding and add it to the FAISS index
                    embedding = embedding_model.encode([text])[0]
                    index.add(np.array([embedding], dtype=np.float32))
                    documents.append(text)
                    texts.append(text)
                except LangDetectException:
                    st.error("Could not detect the language of the webpage.")
            except Exception as e:
                st.error(f"Error processing URL: {e}")
st.write("Data processed and added to knowledge base!")
# Provide a summary of the uploaded content
for i, text in enumerate(texts):
st.write(f"Summary of Document {i+1}:")
st.write(text[:500] + "...") # Display first 500 characters as a summary
# Page 2: Q&A Interface
elif page == "Q&A":
    st.title("Ask the Knowledge Base")
    user_query = st.text_input("Enter your query:")
    if user_query:
        if index.ntotal == 0:
            st.warning("The knowledge base is empty. Please upload documents first.")
        else:
            try:
                detected_query_lang = detect(user_query)
                # Translate the query to English if needed, since the QA model is English-only
                if detected_query_lang != "en":
                    st.write(f"Translating query from {detected_query_lang} to English")
                    user_query = translate_text(user_query, detected_query_lang, "en")
                query_embedding = embedding_model.encode([user_query])
                # Retrieve the top 5 documents (fewer if the index holds fewer)
                k = min(5, index.ntotal)
                D, I = index.search(np.array(query_embedding, dtype=np.float32), k=k)
                context = " ".join(documents[i] for i in I[0] if i != -1)
                # Pass the (translated) query and retrieved context to the QA pipeline
                result = qa_pipeline(question=user_query, context=context)
                st.write(f"Answer: {result['answer']}")
            except LangDetectException:
                st.error("Could not detect the language of the query.")
            except Exception as e:
                st.error(f"Error during Q&A processing: {e}")