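"""Opexa learning chatbot.

Loads PDF and JSON documents, splits them into chunks, indexes them in a
FAISS vector store built from sentence-transformer embeddings, and answers
user questions through the Gemini API using the retrieved context.
"""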
import os
import google.generativeai as genai
from dotenv import load_dotenv
from pathlib import Path
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import re  # used by _format_response to detect list items
import itertools  # used by _format_response to collapse duplicate lines
# Load environment variables
load_dotenv()
# Initialize Gemini API
genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
model = genai.GenerativeModel('gemini-1.5-pro') # Using stable version instead of preview
# Initialize the sentence transformer model for embeddings
embedder = SentenceTransformer('all-mpnet-base-v2')
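# Note: all-mpnet-base-v2 produces 768-dimensional embeddings; the FAISS index
# dimension below is taken from the vectors at runtime rather than hard-coded.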
class LearningChatbot:
def __init__(self, docs_path="./documents"):
"""Initialize chatbot with document path"""
self.docs_path = docs_path
self.vector_store = None
self.documents = []
self.initialize_knowledge_base()
def _load_json_file(self, file_path):
"""Load and process JSON file into document chunks"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Convert JSON to text chunks
chunks = []
def process_json(obj, parent_key=''):
if isinstance(obj, dict):
for key, value in obj.items():
new_key = f"{parent_key}.{key}" if parent_key else key
if isinstance(value, (dict, list)):
process_json(value, new_key)
else:
chunks.append(f"{new_key}: {value}")
elif isinstance(obj, list):
for i, item in enumerate(obj):
new_key = f"{parent_key}[{i}]"
if isinstance(item, (dict, list)):
process_json(item, new_key)
else:
chunks.append(f"{new_key}: {item}")
process_json(data)
return chunks
except Exception as e:
print(f"Error loading JSON file {file_path}: {str(e)}")
return []
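    # Illustrative trace (not from the repo's data): a file containing
    #   {"faq": [{"q": "What is OpexA?", "a": "An EdTech platform"}]}
    # is flattened by _load_json_file into the chunks
    #   ["faq[0].q: What is OpexA?", "faq[0].a: An EdTech platform"]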
def initialize_knowledge_base(self):
"""Load and process documents into vector store with memory management"""
try:
print("Loading documents...")
self.documents = []
# Process files in batches
batch_size = 5
all_files = list(Path(self.docs_path).glob("**/*.*"))
for i in range(0, len(all_files), batch_size):
batch_files = all_files[i:i + batch_size]
batch_docs = []
for file in batch_files:
try:
if file.suffix.lower() == '.pdf':
loader = PyPDFLoader(str(file))
batch_docs.extend(loader.load())
elif file.suffix.lower() == '.json':
chunks = self._load_json_file(str(file))
# Convert chunks to document format
batch_docs.extend([
Document(page_content=chunk, metadata={"source": str(file)})
for chunk in chunks
])
except Exception as e:
print(f"Error loading {file}: {str(e)}")
continue
self.documents.extend(batch_docs)
# Clear memory after each batch
batch_docs = None
print(f"Loaded {len(self.documents)} documents")
# Memory-efficient text splitting
print("Splitting text...")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=256, # Reduced chunk size
chunk_overlap=20, # Reduced overlap
separators=["\n\n", "\n", ".", "!", "?", ";", ",", " "],
length_function=len,
)
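            # RecursiveCharacterTextSplitter tries these separators in order,
            # recursing to the next one until each chunk fits under chunk_size.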
# Split documents in batches
processed_chunks = []
            batch_size = 50  # Split 50 documents at a time
for i in range(0, len(self.documents), batch_size):
batch = self.documents[i:i + batch_size]
chunks = text_splitter.split_documents(batch)
processed_chunks.extend(chunks)
# Clear batch from memory
batch = None
self.documents = processed_chunks
print(f"Created {len(self.documents)} chunks")
# Generate embeddings in batches
print("Generating embeddings...")
embeddings = []
batch_size = 32 # Process 32 embeddings at a time
for i in range(0, len(self.documents), batch_size):
batch = self.documents[i:i + batch_size]
texts = [doc.page_content for doc in batch]
# Generate embeddings for batch
batch_embeddings = embedder.encode(texts)
embeddings.extend(batch_embeddings)
# Clear batch from memory
batch = None
texts = None
batch_embeddings = None
            # Initialize FAISS index with memory-efficient approach
            print("Building search index...")
            if not embeddings:
                raise ValueError(f"No usable documents found in {self.docs_path}")
            dimension = embeddings[0].shape[0]
            self.vector_store = faiss.IndexFlatL2(dimension)
            # Add embeddings in batches (FAISS requires float32 input)
            batch_size = 1000  # Add 1000 vectors at a time
            embeddings_array = np.array(embeddings, dtype=np.float32)
for i in range(0, len(embeddings_array), batch_size):
batch = embeddings_array[i:i + batch_size]
self.vector_store.add(batch)
# Clear batch from memory
batch = None
# Clear large objects from memory
embeddings = None
embeddings_array = None
print("Knowledge base initialization complete")
except Exception as e:
print(f"Error initializing knowledge base: {str(e)}")
raise e
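    # Note: the index is rebuilt from scratch on every start. A possible
    # extension (not part of this file) would be to persist it between runs:
    #   faiss.write_index(self.vector_store, "index.faiss")  # after building
    #   self.vector_store = faiss.read_index("index.faiss")  # on later startups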
def verify_knowledge_base(self):
"""
Verify if the knowledge base is properly initialized
Returns:
bool: True if vector store and documents are ready
"""
try:
return (
self.vector_store is not None and
len(self.documents) > 0 and
hasattr(self.vector_store, 'ntotal') and
self.vector_store.ntotal > 0
)
except Exception as e:
print(f"Error verifying knowledge base: {str(e)}")
return False
    def get_relevant_context(self, query, k=3):
        """Retrieve the k most relevant document chunks for a query"""
        try:
            # Generate query embedding (FAISS expects a 2D float32 array)
            query_vector = embedder.encode([query]).astype(np.float32)
            # A single search call scans the entire index; FAISS handles large
            # flat indexes efficiently, so no manual batching is needed here
            distances, indices = self.vector_store.search(query_vector, k)
            # FAISS pads with -1 when fewer than k vectors exist, so filter
            relevant_docs = [
                self.documents[i].page_content
                for i in indices[0]
                if 0 <= i < len(self.documents)
            ]
            return "\n".join(relevant_docs)
        except Exception as e:
            print(f"Error retrieving context: {str(e)}")
            return ""
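    # IndexFlatL2 returns squared L2 distances, so smaller values mean closer
    # matches. A hypothetical call such as
    #   bot.get_relevant_context("How do I start a career assessment?", k=3)
    # returns up to three chunk texts joined by newlines.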
def _construct_educational_prompt(self, query, context):
"""
Construct an OpexA-focused prompt that delivers clear, concise, and actionable responses
"""
base_prompt = f"""You are an expert assistant for OpexA, an EdTech platform focused on career growth for IT professionals,
businesses, and public sector users. Your goal is to deliver clear, concise, and actionable answers while maintaining a friendly
and supportive tone.
Context from OpexA materials:
{context}
User Question: {query}
Key Guidelines for Your Response:
1. User Segments - Tailor your response based on user type:
β€’ Beginners: Offer foundational guidance and basic concepts
β€’ Career Changers: Focus on transition plans and skill mapping
β€’ Experienced Professionals: Provide advanced insights and industry-specific details
β€’ Business/Public Sector: Address organizational needs and compliance
Response Structure:
1. Start with direct, relevant information
2. Use bullet points for lists and steps
3. Include practical examples or analogies
4. Add proactive tips or next steps
5. End with an engaging question
Handling Special Cases:
β€’ Unclear Questions: Ask for clarification (e.g., "Are you interested in career assessments or account settings?")
β€’ Out-of-Scope: Politely redirect to available features
β€’ Privacy Concerns: Provide reassurance about data protection
Style Guidelines:
β€’ Use natural, conversational language
β€’ Include relevant emojis sparingly (πŸš€ for growth, πŸ”’ for security)
β€’ Format lists and steps with bullet points (β€’)
β€’ Keep responses concise but informative
β€’ End with engaging questions like "What's your next goal?" or "Ready to explore more?"
Now, please provide a helpful response to: {query}"""
return base_prompt
def _format_response(self, response):
"""Format the response with consistent list formatting and proper line breaks"""
try:
text = response.text
# Split into paragraphs
paragraphs = text.split('\n\n')
formatted_paragraphs = []
for p in paragraphs:
lines = p.split('\n')
formatted_lines = []
in_list = False
previous_was_list = False
for line in lines:
line = line.strip()
                    # Check if this is a list item (bullet or any numbered item,
                    # including multi-digit numbers like "10.")
                    is_list_item = bool(re.match(r'(?:[β€’\-*β—‹Β·β–Ίβ†’]|\d+\.)\s', line))
# Add extra line break before list items (except for the first one)
if is_list_item and previous_was_list:
formatted_lines.append('') # Add empty line between list items
if is_list_item:
# Standardize bullet points
if line[0].isdigit(): # If it's a numbered list
line = 'β€’ ' + line[line.find(' ')+1:].strip()
else: # If it's already a bullet point
line = 'β€’ ' + line[1:].strip()
in_list = True
previous_was_list = True
else:
# If this looks like it should be a list item but missing bullet
if in_list and line and not line.endswith(':'):
if previous_was_list:
formatted_lines.append('') # Add empty line between list items
line = 'β€’ ' + line
previous_was_list = True
else:
in_list = False
previous_was_list = False
formatted_lines.append(line)
# Join lines with appropriate spacing
formatted_text = '\n'.join(formatted_lines)
# Add extra newline before lists for better readability
if any(line.startswith('β€’ ') for line in formatted_lines):
formatted_text = '\n' + formatted_text
formatted_paragraphs.append(formatted_text)
# Join paragraphs with double newlines
formatted_text = '\n\n'.join(formatted_paragraphs)
# Clean up multiple consecutive newlines
formatted_text = '\n'.join(line for line, _ in itertools.groupby(formatted_text.splitlines()))
# If response is too long, keep main points while preserving list structure
if len(formatted_text) > 500:
main_paragraphs = []
# Always keep the first paragraph (usually the main explanation)
main_paragraphs.append(formatted_paragraphs[0])
# Keep all bullet point lists
for p in formatted_paragraphs[1:]:
if 'β€’ ' in p:
main_paragraphs.append(p)
formatted_text = '\n\n'.join(main_paragraphs)
return formatted_text.strip()
        except Exception as e:
            return ("I apologize, but I ran into an issue formatting the response. "
                    f"Technical note: {str(e)}")
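    # Illustrative example: given the model text
    #   "Here's how:\n1. First step\n2. Second step"
    # this formatter returns
    #   "Here's how:\nβ€’ First step\n\nβ€’ Second step"
    # (numbered items become bullets, separated by blank lines).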
def _handle_generation_error(self, error):
"""Handle errors with a natural, supportive tone"""
return f"""I apologize, but I'm having trouble helping you at the moment.
This might be because:
- I'm still processing some information
- There might be a technical issue
- The question might need to be more specific
Would you mind trying to rephrase your question? I want to make sure I give you the best help possible.
Technical note: {str(error)}"""
def generate_response(self, query):
"""Generate natural, personalized responses for students"""
try:
if not self.verify_knowledge_base():
return """I'm having trouble accessing our learning materials at the moment.
Could you make sure all the study materials are properly loaded?
This helps me give you the most accurate and helpful responses."""
# Get relevant context
context = self.get_relevant_context(query, k=3)
# Construct educational prompt
prompt = self._construct_educational_prompt(query, context)
# Generate response with simplified parameters
response = model.generate_content(prompt)
# Return natural response
return self._format_response(response)
except Exception as e:
return self._handle_generation_error(e)
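# Minimal usage sketch (this file ships no entry point of its own): assumes a
# GOOGLE_API_KEY in .env and a ./documents directory of PDF/JSON files.
if __name__ == "__main__":
    bot = LearningChatbot(docs_path="./documents")
    while True:
        question = input("You: ").strip()
        if question.lower() in {"quit", "exit"}:
            break
        print("Bot:", bot.generate_response(question))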