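# app.py — AI-Powered Multi-Document Chatbot with Voice Output
# Pulls documents from a Google Drive folder, indexes them in a Chroma vector
# store, answers questions with a Gemini model via LangChain, and reads the
# answer aloud with gTTS. Requires the GOOGLE_API_KEY_1 and
# SERVICE_ACCOUNT_JSON environment variables.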
import os
import time
import json
import logging
import threading
import concurrent.futures

import gradio as gr
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
from google.oauth2 import service_account
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain.chains import RetrievalQA
from langchain_google_genai import ChatGoogleGenerativeAI
from gtts import gTTS
# ✅ Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Tracks generated audio files and their creation times for periodic cleanup
temp_file_map = {}

logging.info("🔑 Loading API keys...")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY_1")
SERVICE_ACCOUNT_JSON = os.getenv("SERVICE_ACCOUNT_JSON")

if not GOOGLE_API_KEY or not SERVICE_ACCOUNT_JSON:
    logging.error("❌ Missing API Key or Service Account JSON.")
    raise ValueError("❌ Missing API Key or Service Account JSON. Please add them as environment variables.")

os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

SERVICE_ACCOUNT_INFO = json.loads(SERVICE_ACCOUNT_JSON)
SCOPES = ["https://www.googleapis.com/auth/drive"]
FOLDER_ID = "1xqOpwgwUoiJYf9GkeuB4dayme4zJcujf"

# Pass SCOPES explicitly so the credentials are limited to the Drive API
creds = service_account.Credentials.from_service_account_info(SERVICE_ACCOUNT_INFO, scopes=SCOPES)
drive_service = build("drive", "v3", credentials=creds)

vector_store = None
file_id_map = {}
temp_dir = "./temp_downloads"
os.makedirs(temp_dir, exist_ok=True)
def get_files_from_drive():
    global file_id_map
    logging.info("📂 Fetching files from Google Drive...")
    query = f"'{FOLDER_ID}' in parents and trashed = false"
    results = drive_service.files().list(q=query, fields="files(id, name)").execute()
    files = results.get("files", [])
    file_id_map = {file["name"]: file["id"] for file in files}
    return list(file_id_map.keys())
def download_file(file_id, file_name):
    logging.info(f"📥 Downloading file: {file_name}")
    file_path = os.path.join(temp_dir, file_name)
    request = drive_service.files().get_media(fileId=file_id)
    with open(file_path, "wb") as f:
        downloader = MediaIoBaseDownload(f, request)
        done = False
        while not done:
            _, done = downloader.next_chunk()
    return file_path
def load_document(file_name, file_path):
    try:
        if file_name.endswith(".pdf"):
            return PyPDFLoader(file_path).load()
        elif file_name.endswith(".txt"):
            return TextLoader(file_path).load()
        elif file_name.endswith(".docx"):
            return Docx2txtLoader(file_path).load()
        else:
            logging.warning(f"⚠️ Unsupported file type: {file_name}")
            return []
    except Exception as e:
        logging.error(f"❌ Error loading {file_name}: {e}")
        return []
def process_documents(selected_files):
    global vector_store
    # ✅ Clear the existing vector store before processing new documents
    if vector_store is not None:
        logging.info("🗑️ Clearing previous document embeddings...")
        vector_store.delete_collection()  # Clears existing stored data

    def fetch_and_load(file_name):
        # Download and parse one file inside the worker thread, so downloads
        # and parsing overlap across files instead of downloading serially
        file_path = download_file(file_id_map[file_name], file_name)
        return load_document(file_name, file_path)

    docs = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_file = {executor.submit(fetch_and_load, file_name): file_name for file_name in selected_files}
        for future in concurrent.futures.as_completed(future_to_file):
            docs.extend(future.result())

    total_words = sum(len(doc.page_content.split()) for doc in docs)
    if total_words < 1000:
        chunk_size, chunk_overlap, file_size_category = 500, 50, "small"
    elif total_words < 5000:
        chunk_size, chunk_overlap, file_size_category = 1000, 100, "medium"
    else:
        chunk_size, chunk_overlap, file_size_category = 2000, 200, "large"

    logging.info(f"📄 Document Size: {total_words} words | Category: {file_size_category} | Chunk Size: {chunk_size}")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    split_docs = text_splitter.split_documents(docs)

    embedding_model = (
        "sentence-transformers/all-MiniLM-L6-v2"
        if file_size_category == "small"
        else "sentence-transformers/paraphrase-MiniLM-L3-v2"
    )
    logging.info(f"🧠 Using Transformer Model: {embedding_model}")
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

    # ✅ Create a new Chroma vector store for the selected documents
    vector_store = Chroma.from_documents(split_docs, embeddings)
    return "✅ Documents processed successfully!"
def query_document(question):
    if vector_store is None:
        return "❌ No documents processed.", None

    # ✅ Fetch stored documents
    stored_docs = vector_store.get()["documents"]

    # ✅ Calculate total word count safely (entries may be raw strings or Document objects)
    total_words = sum(len(doc.split()) if isinstance(doc, str) else len(doc.page_content.split()) for doc in stored_docs)

    # ✅ Categorize file size and set retrieval depth
    if total_words < 500:
        file_size_category = "small"
        k_value = 3
        prompt_prefix = "Provide a **concise** response focusing on key points."
    elif total_words < 2000:
        file_size_category = "medium"
        k_value = 5
        prompt_prefix = "Provide a **detailed response** with examples and key insights."
    else:
        file_size_category = "large"
        k_value = 10
        prompt_prefix = "Provide a **comprehensive and structured response**, including step-by-step analysis and explanations."

    logging.info(f"🔎 Querying Vector Store | File Size: {file_size_category} | Search Depth: {k_value}")

    # ✅ Set up the retriever
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k_value})

    # ✅ Dynamically select the model based on document size
    if file_size_category in ["small", "medium"]:
        model_name = "gemini-2.0-pro-exp-02-05"
    else:
        model_name = "gemini-2.0-flash"
    logging.info(f"🤖 Using LLM Model: {model_name}")

    # ✅ Create the detailed prompt
    detailed_prompt = f"""{prompt_prefix}
- Ensure clarity and completeness.
- Highlight the most relevant information.
**Question:** {question}
"""

    # ✅ Invoke the LLM through a RetrievalQA chain
    model = ChatGoogleGenerativeAI(model=model_name, google_api_key=GOOGLE_API_KEY)
    qa_chain = RetrievalQA.from_chain_type(llm=model, retriever=retriever)
    response = qa_chain.invoke({"query": detailed_prompt})["result"]
    logging.info(f"📝 Bot Output: {response[:200]}...")  # Log only the first 200 chars for readability

    # ✅ Convert the response to speech; key the cleanup map by full path
    tts = gTTS(text=response, lang="en")
    temp_audio_path = os.path.join(temp_dir, "response.mp3")
    tts.save(temp_audio_path)
    temp_file_map[temp_audio_path] = time.time()
    return response, temp_audio_path
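# The script imports `threading` and records creation times in temp_file_map,
# but nothing ever prunes old audio files. A minimal cleanup sketch, assuming
# a 10-minute lifetime for generated files (the TTL and polling interval are
# illustrative, not from the original):
def cleanup_temp_files(ttl_seconds=600, interval_seconds=120):
    while True:
        now = time.time()
        for path, created_at in list(temp_file_map.items()):
            if now - created_at > ttl_seconds:
                try:
                    os.remove(path)
                except OSError:
                    pass  # Already removed or in use; drop the stale entry either way
                temp_file_map.pop(path, None)
        time.sleep(interval_seconds)

# Run the cleanup loop as a daemon thread so it never blocks shutdown
threading.Thread(target=cleanup_temp_files, daemon=True).start()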
# ✅ Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 📄 AI-Powered Multi-Document Chatbot with Voice Output")
    file_dropdown = gr.Dropdown(choices=get_files_from_drive(), label="📂 Select Files", multiselect=True)
    refresh_button = gr.Button("🔄 Refresh Files")
    process_button = gr.Button("🚀 Process Documents")
    user_input = gr.Textbox(label="🔎 Ask a Question")
    submit_button = gr.Button("💬 Get Answer")
    response_output = gr.Textbox(label="📝 Response")
    audio_output = gr.Audio(label="🔊 Audio Response")

    # 🔄 Refresh the file list from Drive
    def refresh_files():
        return gr.update(choices=get_files_from_drive())

    # ✅ Wire up the buttons
    refresh_button.click(refresh_files, outputs=file_dropdown)
    process_button.click(process_documents, inputs=file_dropdown, outputs=response_output)
    submit_button.click(query_document, inputs=user_input, outputs=[response_output, audio_output])

demo.launch()