# Source: Hugging Face Space (status when captured: Sleeping)
import streamlit as st | |
from streamlit_chat import message | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.memory import ConversationBufferMemory | |
from langchain_core.prompts import PromptTemplate | |
from langchain_community.vectorstores import FAISS | |
import pdfplumber | |
import docx2txt | |
from langchain_community.embeddings import OllamaEmbeddings | |
from langchain_groq import ChatGroq | |
from dotenv import load_dotenv | |
from easygoogletranslate import EasyGoogleTranslate | |
import os | |
import csv | |
import re | |
from io import StringIO | |
import speech_recognition as sr | |
import pygame | |
from threading import Thread | |
from gtts import gTTS | |
import gc | |
import torch | |
# Empty CUDA_VISIBLE_DEVICES and a single torch thread: presumably to keep the
# app CPU-only and light on resources -- confirm against the deployment target.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_num_threads(1)

# Load variables from a .env file (if present) before reading the API key.
load_dotenv()
groq_api_key = os.environ.get("GROQ_API_KEY")

# Upper bound on how many uploaded files are processed in one run.
MAX_DOCUMENTS = 5
def initialize_session_state():
    """Seed ``st.session_state`` with every key the app reads.

    Keys that already exist are left untouched, so calling this on each
    rerun never discards the running conversation.
    """
    state = st.session_state
    # The greeting literals were mis-decoded emoji; restored to real characters.
    greeting = "Hello! Ask me anything about 🤗"
    defaults = (
        ('history', lambda: []),
        ('generated', lambda: [greeting]),
        ('past', lambda: ["Hey! 👋"]),
        ('translated', lambda: [greeting]),
        # One flag per generated message; evaluated after 'generated' exists.
        ('translation_requested', lambda: [False] * len(state['generated'])),
        ('chain', lambda: None),
        ('vector_store', lambda: None),
    )
    for key, make_default in defaults:
        if key not in state:
            state[key] = make_default()
def translate_text(text, target_language='en'):
    """Translate ``text`` into ``target_language``.

    If the translation call raises, the error is shown in the UI and the
    original ``text`` is returned unchanged.
    """
    translator = EasyGoogleTranslate(target_language=target_language)
    try:
        translated = translator.translate(text)
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
    else:
        return translated
# Markdown emphasis/heading/table markers and common bullet glyphs
# (bullet, black circle, black square, black diamond, small black square).
_SPEECH_MARKUP_RE = re.compile('[*_~#|\u2022\u25cf\u25a0\u25c6\u25aa]')

# Abbreviations expanded to their spoken form (matched case-insensitively).
_SPEECH_ABBREVIATIONS = {
    'e.g.': 'for example',
    'i.e.': 'that is',
    'etc.': 'et cetera',
    'vs.': 'versus',
    'fig.': 'figure',
    'approx.': 'approximately',
}
# Leading \b keeps 'fig.' from matching inside words such as 'config.'.
_SPEECH_ABBREVIATION_RE = re.compile(
    r'\b(?:' + '|'.join(re.escape(abbr) for abbr in _SPEECH_ABBREVIATIONS) + r')',
    re.IGNORECASE,
)

# Symbols replaced by words; padded with spaces so the word never fuses with
# its neighbours ("50%" -> "50 percent"). '#' is absent on purpose: it is
# stripped as markup before this table is applied.
_SPEECH_SYMBOLS = str.maketrans({
    '&': ' and ',
    '%': ' percent ',
    '$': ' dollars ',
    '\u20ac': ' euros ',
    '\u00a3': ' pounds ',
    '@': ' at ',
})


def clean_text_for_speech(text):
    """Return ``text`` rewritten so a text-to-speech engine reads it naturally.

    Steps, in order: strip markup and bullet glyphs, expand abbreviations,
    spell decimals as "<int> point <frac>", replace symbols with words, turn
    colons/semicolons/hyphens into spaces, drop brackets, put one space after
    sentence punctuation, and collapse all whitespace (including newlines).
    """
    text = _SPEECH_MARKUP_RE.sub('', text)
    # Abbreviations and decimals must be handled before sentence punctuation
    # is spaced out, otherwise "3.14" becomes "3. 14" and "e.g." becomes
    # "e. g." and neither rule can match any more.
    text = _SPEECH_ABBREVIATION_RE.sub(
        lambda match: _SPEECH_ABBREVIATIONS[match.group(0).lower()], text
    )
    text = re.sub(r'(\d+)\.(\d+)', r'\1 point \2', text)
    text = text.translate(_SPEECH_SYMBOLS)
    text = re.sub(r'[:;-]', ' ', text)
    text = re.sub(r'[(){}\[\]]', '', text)
    text = re.sub(r'([.!?])\s*', r'\1 ', text)
    # Collapse last so every substitution above may introduce extra spaces.
    return re.sub(r'\s+', ' ', text).strip()
def text_to_speech(text, language='en', speed=1.0):
    """Synthesize ``text`` with gTTS and return the MP3 data as bytes.

    Args:
        text: Text to speak; it is passed through ``clean_text_for_speech`` first.
        language: Language code handed to gTTS as ``lang``.
        speed: Values below 1.0 select gTTS's slow mode; anything else is normal speed.

    Returns:
        The encoded MP3 audio.
    """
    # Local import: the module's own import block only pulls StringIO from io.
    from io import BytesIO

    cleaned_text = clean_text_for_speech(text)
    tts = gTTS(text=cleaned_text, lang=language, slow=(speed < 1.0))
    # Render into memory instead of a fixed "output.mp3" in the working
    # directory, which concurrent sessions would overwrite and nobody removed.
    audio_buffer = BytesIO()
    tts.write_to_fp(audio_buffer)
    return audio_buffer.getvalue()
# Instruction template for document analysis; expects `previous_context` and `query`.
_ANALYSIS_PROMPT_TEMPLATE = """
You are an expert analyst with deep knowledge across various fields. Your task is to provide an in-depth, comprehensive analysis of the uploaded documents. Approach each question with critical thinking and attention to detail.
You are only allowed to answer questions directly related to the content of the uploaded documents.
If a question is outside the scope of the documents, respond with: 'I'm sorry, I can only answer questions about the uploaded documents.'
Guidelines for Analysis:
1. Document Overview:
- Identify the type of document(s) (research paper, report, data set, etc.)
- Summarize the main topic and purpose of each document
2. Content Analysis:
- For research papers: Analyze the abstract, introduction, methodology, results, discussion, and conclusion
- For reports: Examine executive summary, key findings, and recommendations
- For data sets: Describe the structure, variables, and any apparent trends
3. Key Points and Findings:
- Highlight the most significant information and insights from each document
- Identify any unique or surprising elements in the content
4. Contextual Analysis:
- Place the information in a broader context within its field
- Discuss how this information relates to current trends or issues
5. Critical Evaluation:
- Assess the strengths and limitations of the presented information
- Identify any potential biases or gaps in the data or arguments
6. Implications and Applications:
- Discuss the potential impact of the findings or information
- Suggest possible applications or areas for further research
7. Comparative Analysis (if multiple documents):
- Compare and contrast information across different documents
- Identify any conflicting data or viewpoints
8. Data Interpretation:
- For numerical data: Provide clear explanations of statistics or trends
- For qualitative information: Offer interpretations of key quotes or concepts
9. Sourcing and Credibility:
- Comment on the credibility of the sources (if apparent)
- Note any references to other important works in the field
10. Comprehensive Response:
- Ensure all aspects of the question are addressed
- Provide a balanced view, considering multiple perspectives if applicable
Remember to maintain an objective, analytical tone. Your goal is to provide the most thorough and insightful analysis possible based on the available documents.
Previous Context: {previous_context}
Question: {query}
"""


def conversation_chat(query, chain, history):
    """Ask ``chain`` the user's ``query`` and record the exchange.

    Args:
        query: The user's question.
        chain: The conversational retrieval chain, or None when no document
            has been indexed yet.
        history: List of ``(query, answer)`` tuples; the new exchange is
            appended in place.

    Returns:
        The answer text, a fallback apology when the chain result has no
        "answer" key, or an upload reminder when ``chain`` is None.
    """
    if chain is None:
        # The session starts with chain=None and the chat form is shown even
        # before an upload, so answer gracefully instead of failing on
        # None.invoke. Nothing is added to history: no real exchange happened.
        return "Please upload a document before asking a question."

    prompt = PromptTemplate.from_template(_ANALYSIS_PROMPT_TEMPLATE)
    # NOTE(review): `prompt` is forwarded to invoke() as an extra keyword, and
    # its variables (previous_context, query) differ from the question /
    # chat_history inputs given here. It looks like the chain never consumes
    # this template -- confirm, and if so wire it in where the chain is built.
    result = chain.invoke({"question": query, "chat_history": history}, prompt=prompt)
    answer = result.get("answer", "I'm sorry, I couldn't generate an answer.")
    history.append((query, answer))
    return answer
def display_chat_history(chain):
    """Render the conversation with per-message controls and the question form.

    For every exchange this shows the user and bot messages, a language
    selector, a translate button and a play button; once a translation has
    been requested the translated message and its own play button are shown
    too. Submitting the form sends the question to ``chain``, appends the
    exchange to session state and reruns the script.

    Args:
        chain: Conversational chain passed through to ``conversation_chat``.
    """
    state = st.session_state
    st.write("Chat History:")
    exchanges = zip(state['past'], state['generated'])
    for i, (user_text, bot_text) in enumerate(exchanges):
        message(user_text, is_user=True, key=f'{i}_user', avatar_style="avataaars", seed="Aneka")
        message(bot_text, key=f'{i}_bot', avatar_style="bottts", seed="Aneka")

        lang_col, translate_col, play_col = st.columns([2, 1, 1])
        with lang_col:
            dest_language = st.selectbox('Select language for translation:',
                                         options=['hi', 'kn'],
                                         index=0,
                                         key=f'{i}_lang_select')
        with translate_col:
            if st.button(f'Translate Message {i}', key=f'{i}_translate'):
                state['translated'][i] = translate_text(bot_text, target_language=dest_language)
                state['translation_requested'][i] = True
                # st.rerun() matches the other rerun calls in this file;
                # st.experimental_rerun() is the deprecated spelling.
                st.rerun()
        with play_col:
            if st.button(f'Play Message {i}', key=f'{i}_play'):
                st.audio(text_to_speech(bot_text), format="audio/mp3")

        if state['translation_requested'][i]:
            message(state['translated'][i], key=f'{i}_bot_translated', avatar_style="bottts", seed="Aneka")
            if st.button(f'Play Translated Message {i}', key=f'{i}_play_translated'):
                # Spoken in the language currently chosen in this row's selector.
                st.audio(text_to_speech(state['translated'][i], dest_language), format="audio/mp3")

    with st.form(key='user_input_form'):
        user_input = st.text_input("Ask questions about your uploaded documents:", key="user_input")
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        output = conversation_chat(user_input, chain, state['history'])
        state['past'].append(user_input)
        state['generated'].append(output)
        # Until a translation is requested, the "translated" slot mirrors the answer.
        state['translated'].append(output)
        state['translation_requested'].append(False)
        st.rerun()
def process_file(file):
    """Extract plain text from an uploaded file.

    The reported MIME type decides how the file is read. When that type is
    not one of the four known ones, the file name's extension is used instead,
    because the MIME type is supplied by the uploader and may not reflect the
    real format (NOTE(review): e.g. some systems label .csv files as a
    spreadsheet type -- confirm which cases occur in practice).

    Args:
        file: Uploaded file object exposing ``type``, ``name`` and ``getvalue()``.

    Returns:
        The extracted text, or "" (after showing an error) for unsupported files.
    """
    kinds_by_mime = {
        "application/pdf": "pdf",
        "text/plain": "txt",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
        "text/csv": "csv",
    }
    kinds_by_extension = {".pdf": "pdf", ".txt": "txt", ".docx": "docx", ".csv": "csv"}

    extension = os.path.splitext(getattr(file, "name", "") or "")[1].lower()
    # MIME type wins when recognized, so previously supported uploads behave as before.
    kind = kinds_by_mime.get(file.type) or kinds_by_extension.get(extension)

    if kind == "pdf":
        return process_pdf(file)
    if kind == "txt":
        return file.getvalue().decode("utf-8")
    if kind == "docx":
        return docx2txt.process(file)
    if kind == "csv":
        return process_csv(file)

    st.error(f"Unsupported file type: {file.type}")
    return ""
def process_csv(file):
    """Render an uploaded CSV file as readable text.

    The first row is treated as the header line; each following row becomes
    ``Row <n>: a | b | c`` and a total row count closes the text.

    Args:
        file: Uploaded file object whose ``getvalue()`` returns UTF-8 bytes.

    Returns:
        The rendered text. If decoding or parsing fails, the error is shown
        in the UI and whatever was rendered up to that point is returned.
    """
    parts = []
    try:
        reader = csv.reader(StringIO(file.getvalue().decode('utf-8')))
        headers = next(reader, None)
        if headers:
            parts.append(f"CSV Headers: {', '.join(headers)}\n\n")
        # Start at 0 so a header-only or empty file reports "Total rows: 0"
        # instead of failing on an unbound loop variable.
        row_count = 0
        for row_count, row in enumerate(reader, 1):
            parts.append(f"Row {row_count}: {' | '.join(row)}\n")
        parts.append(f"\nTotal rows: {row_count}\n")
    except (UnicodeDecodeError, csv.Error) as e:
        st.error(f"Error reading CSV file: {e}")
    return "".join(parts)
# A section keyword, the rest of its line, then everything up to the next line
# that starts with a section keyword (or the end of the text).
# Case-insensitivity is set through compile flags: an inline (?i) in the
# middle of a pattern is rejected by Python 3.11+.
_PDF_SECTION_RE = re.compile(
    r'(abstract|introduction|methodology|results|discussion|conclusion).*?\n(.*?)'
    r'(?=\n(?:abstract|introduction|methodology|results|discussion|conclusion)|$)',
    re.IGNORECASE | re.DOTALL,
)


def process_pdf(file):
    """Extract text from a PDF and, where possible, regroup it by section.

    Every page with extractable text is emitted as ``[Page n]`` followed by
    its text. If section keywords are found, the result is rebuilt as
    ``<Keyword>:`` blocks; otherwise the page-tagged text is returned as is.

    Args:
        file: File-like PDF accepted by ``pdfplumber.open``.

    Returns:
        The sectioned text when at least one keyword matched, else the raw
        page-tagged text.
    """
    page_blocks = []
    with pdfplumber.open(file) as pdf:
        for page_num, page in enumerate(pdf.pages, 1):
            page_text = page.extract_text()
            if page_text:
                page_blocks.append(f"[Page {page_num}]\n{page_text}\n\n")
    text = "".join(page_blocks)

    # NOTE(review): the sectioned output keeps only what the pattern captures.
    # Text before the first keyword and the remainder of each keyword's own
    # line are not included -- confirm this loss is acceptable for retrieval.
    sections = _PDF_SECTION_RE.findall(text)
    structured_text = "\n\n".join(
        f"{heading.capitalize()}:\n{body}" for heading, body in sections
    )
    return structured_text if structured_text else text
def recognize_speech():
    """Record one utterance from the microphone and transcribe it.

    Progress and errors are reported through Streamlit status messages.

    Returns:
        The recognized text, or "" when the microphone cannot be opened, no
        speech starts within 10 seconds, the audio is not understood, or the
        recognition request fails.
    """
    recognizer = sr.Recognizer()
    try:
        with sr.Microphone() as source:
            st.write("Listening... Please speak now.")
            st.info("Listening for up to 10 seconds...")
            recognizer.adjust_for_ambient_noise(source, duration=1)
            audio = recognizer.listen(source, timeout=10, phrase_time_limit=5)
    except sr.WaitTimeoutError:
        st.warning("No speech detected. Please try again.")
        return ""
    except (OSError, AttributeError) as e:
        # NOTE(review): opening the microphone is expected to fail with one of
        # these when no input device or audio backend is present (e.g. on a
        # hosted server) -- confirm against the speech_recognition docs.
        st.error(f"Microphone is not available: {e}")
        return ""

    st.success("Audio captured. Processing...")
    try:
        text = recognizer.recognize_google(audio)
    except sr.UnknownValueError:
        st.error("Sorry, I couldn't understand that.")
        return ""
    except sr.RequestError as e:
        st.error(f"Could not request results; {e}")
        return ""
    st.success(f"You said: {text}")
    return text
def create_conversational_chain(vector_store):
    """Build the conversational retrieval chain over ``vector_store``.

    The chain combines a Groq-hosted chat model, a buffer memory stored under
    "chat_history", and a retriever asked for 5 results per query
    (``search_kwargs={"k": 5}``).

    Args:
        vector_store: Vector store exposing ``as_retriever``.

    Returns:
        The configured ``ConversationalRetrievalChain``.
    """
    groq_llm = ChatGroq(groq_api_key=groq_api_key, model_name='llama3-70b-8192')
    chat_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    retriever = vector_store.as_retriever(search_kwargs={"k": 5})
    return ConversationalRetrievalChain.from_llm(
        llm=groq_llm,
        chain_type='stuff',
        retriever=retriever,
        memory=chat_memory,
    )
def _index_uploaded_files(files):
    """Extract, chunk and embed ``files``; store vector store and chain in session state.

    Returns True when an index was built, False when no text could be extracted.
    """
    file_texts = []
    for uploaded_file in files:
        try:
            file_texts.append(f"File: {uploaded_file.name}\n\n{process_file(uploaded_file)}\n\n")
        except Exception as e:
            # Boundary handler: one unreadable file must not abort the others.
            st.error(f"Error processing file {uploaded_file.name}: {e}")
        finally:
            gc.collect()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=4000,
        chunk_overlap=300,
        length_function=len,
        separators=["\n\n", "\n", " ", ""]
    )
    text_chunks = text_splitter.split_text("".join(file_texts))
    if not text_chunks:
        # Every file failed: there is nothing to build a vector index from.
        st.error("No text could be extracted from the uploaded documents.")
        return False

    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    with st.spinner('Analyzing Document...'):
        vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
    st.session_state['vector_store'] = vector_store
    st.session_state['chain'] = create_conversational_chain(vector_store)
    return True


def _handle_voice_query():
    """Record a spoken question, answer it, and queue the spoken answer for playback."""
    chain = st.session_state['chain']
    if chain is None:
        st.warning("Please upload a document before asking a question.")
        return

    recognized_text = recognize_speech()
    if not recognized_text:
        st.warning("No speech input was processed. Please try speaking again.")
        return

    # Get the answer first so the parallel history lists are only extended
    # together and stay the same length if the chain call raises.
    output = conversation_chat(recognized_text, chain, st.session_state['history'])
    st.session_state['past'].append(recognized_text)
    st.session_state['generated'].append(output)
    st.session_state['translated'].append(output)
    st.session_state['translation_requested'].append(False)
    # An audio widget created right before st.rerun() is discarded with the
    # aborted run, so the bytes are parked in session state and rendered on
    # the next pass instead.
    st.session_state['pending_audio'] = text_to_speech(output)
    st.rerun()


def main():
    """Run the Streamlit app: upload and index documents, chat, translate and speak."""
    initialize_session_state()
    st.set_page_config(page_title="DOCS Chatbot & Translator", layout="wide")
    # The title's trailing emoji was mis-decoded; restored to a real character.
    st.title("Smart Document Tool 🤗")
    st.sidebar.header("About App:")
    st.sidebar.write("This app utilizes Streamlit")
    # Was an attribute assignment (`st.sidebar. caption = ...`) that rendered nothing.
    st.sidebar.caption("Your AI Assistant")

    uploaded_files = st.file_uploader("Upload your Docs", type=["pdf", "txt", "docx", "csv"], accept_multiple_files=True)
    if uploaded_files:
        selected_files = uploaded_files[:MAX_DOCUMENTS]
        if len(uploaded_files) > MAX_DOCUMENTS:
            st.warning(f"Only the first {MAX_DOCUMENTS} documents were processed due to memory constraints.")
        # The script reruns on every widget interaction. Rebuild the index
        # only when the set of files changed, so the embeddings are not
        # recomputed and the chain (with its memory) is not replaced each time.
        signature = tuple((f.name, f.size) for f in selected_files)
        already_indexed = (
            'indexed_files' in st.session_state
            and st.session_state['indexed_files'] == signature
        )
        if not already_indexed and _index_uploaded_files(selected_files):
            st.session_state['indexed_files'] = signature

    display_chat_history(st.session_state['chain'])

    # Play the spoken answer queued by the previous run's voice query, once.
    if 'pending_audio' in st.session_state:
        pending_audio = st.session_state['pending_audio']
        del st.session_state['pending_audio']
        st.audio(pending_audio, format="audio/mp3")

    if st.button('Speak Now'):
        _handle_voice_query()
    gc.collect()


if __name__ == "__main__":
    main()