import io
import logging
import os
import time
import uuid

import google.generativeai as genai
import pandas as pd
import PyPDF2
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone, ServerlessSpec

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Configure Gemini API
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
genai.configure(api_key=GEMINI_API_KEY)

# Initialize Pinecone
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
pc = Pinecone(api_key=PINECONE_API_KEY)
cloud = os.environ.get('PINECONE_CLOUD', 'aws')
region = os.environ.get('PINECONE_REGION', 'us-east-1')
spec = ServerlessSpec(cloud=cloud, region=region)

# Define index name and embedding dimension
index_name = "rag-donor-index"
embedding_dimension = 768  # For text-embedding-004

# Check if index exists, create if not
if index_name not in pc.list_indexes().names():
    logger.info(f"Creating Pinecone index: {index_name}")
    pc.create_index(
        name=index_name,
        dimension=embedding_dimension,
        metric="cosine",
        spec=spec
    )
    # Wait for the index to be ready. This is synchronous module-level code, so
    # use time.sleep; an un-awaited asyncio.sleep() would not actually pause.
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)
logger.info(f"Pinecone index {index_name} is ready.")

# Initialize embeddings
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004", google_api_key=GEMINI_API_KEY)
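
# Optional sanity check (a sketch, not part of the original flow): the index
# dimension must match the embedding model's output size, or upserts will fail.
# Uncomment to verify at startup:
# assert len(embeddings.embed_query("dimension check")) == embedding_dimension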

# Shared helper: split raw text into overlapping chunks and wrap each chunk
# as a Document with source and chunk_id metadata
def _split_into_documents(text, filename):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100
    )
    chunks = text_splitter.split_text(text)
    return [
        Document(page_content=chunk, metadata={"source": filename, "chunk_id": str(uuid.uuid4())})
        for chunk in chunks
    ]

# Process an uploaded file (PDF, text, CSV, or XLSX) in memory, without saving locally
def process_uploaded_file(file_stream, filename):
    logger.info(f"Processing uploaded file: {filename}")
    try:
        if filename.lower().endswith('.pdf'):
            logger.info("Processing as PDF file.")
            pdf_reader = PyPDF2.PdfReader(file_stream)
            text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
            documents = _split_into_documents(text, filename)
            logger.info(f"Extracted {len(documents)} chunks from PDF.")
            return documents
        elif filename.lower().endswith(('.txt', '.md')):
            logger.info("Processing as text file.")
            text = file_stream.read().decode('utf-8', errors='replace')
            documents = _split_into_documents(text, filename)
            logger.info(f"Extracted {len(documents)} chunks from text file.")
            return documents
        elif filename.lower().endswith('.csv'):
            logger.info("Processing as CSV file.")
            df = pd.read_csv(file_stream)
            # Flatten the DataFrame to its string representation before chunking
            documents = _split_into_documents(df.to_string(), filename)
            logger.info(f"Extracted {len(documents)} chunks from CSV.")
            return documents
        elif filename.lower().endswith('.xlsx'):
            logger.info("Processing as XLSX file.")
            df = pd.read_excel(file_stream, engine='openpyxl')
            documents = _split_into_documents(df.to_string(), filename)
            logger.info(f"Extracted {len(documents)} chunks from XLSX.")
            return documents
        else:
            raise ValueError("Unsupported file type. Only PDF, text, CSV, and XLSX files are supported.")
    except ValueError:
        # Let the unsupported-type error through unchanged
        raise
    except Exception as e:
        logger.error(f"Error processing file {filename}: {e}")
        raise RuntimeError(f"Error processing file: {e}") from e

# Index documents in Pinecone, batching to stay under request size limits
def index_documents(documents, namespace="chatbot-knowledge", batch_size=50):
    logger.info(f"Indexing {len(documents)} documents in Pinecone.")
    try:
        vector_store = PineconeVectorStore(
            index_name=index_name,
            embedding=embeddings,
            namespace=namespace
        )
        # Batch documents to avoid Pinecone size limits
        num_batches = -(-len(documents) // batch_size)  # ceiling division
        for i in range(0, len(documents), batch_size):
            batch = documents[i:i + batch_size]
            batch_size_bytes = sum(len(doc.page_content.encode('utf-8')) for doc in batch)
            if batch_size_bytes > 4_000_000:
                logger.warning(f"Batch size {batch_size_bytes} bytes exceeds Pinecone limit. Splitting batch.")
                smaller_batch_size = max(1, batch_size // 2)
                num_sub_batches = -(-len(batch) // smaller_batch_size)
                for j in range(0, len(batch), smaller_batch_size):
                    smaller_batch = batch[j:j + smaller_batch_size]
                    vector_store.add_documents(smaller_batch)
                    logger.info(f"Indexed sub-batch {j // smaller_batch_size + 1} of {num_sub_batches}")
            else:
                vector_store.add_documents(batch)
                logger.info(f"Indexed batch {i // batch_size + 1} of {num_batches}")
        logger.info("Document indexing completed.")
        return vector_store
    except Exception as e:
        logger.error(f"Error indexing documents: {e}")
        raise RuntimeError(f"Error indexing documents: {e}") from e

# RAG chatbot: retrieve relevant chunks, then answer with Gemini
def rag_chatbot(query, namespace="chatbot-knowledge"):
    logger.info(f"Processing query: {query}")
    try:
        # Initialize vector store
        vector_store = PineconeVectorStore(
            index_name=index_name,
            embedding=embeddings,
            namespace=namespace
        )
        # Retrieve relevant documents with similarity scores
        relevant_docs_with_scores = vector_store.similarity_search_with_score(query, k=3)
        for doc, score in relevant_docs_with_scores:
            logger.info(f"Score: {score:.4f} | Document: {doc.page_content}")
        # Combine context from retrieved documents
        context = "\n".join(doc.page_content for doc, _ in relevant_docs_with_scores)
        # Create prompt for Gemini
        prompt = f"""You are a helpful chatbot that answers questions based on provided context.

Context:
{context}

User Query: {query}

Provide a concise and accurate answer based on the context. If the context doesn't contain relevant information, say so and provide a general response if applicable.
"""
        # Initialize Gemini model
        model = genai.GenerativeModel('gemini-1.5-flash')
        # Generate response
        response = model.generate_content(prompt)
        logger.info("Generated response successfully.")
        return response.text
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        return f"Error processing query: {e}"