# chatbot / app.py
# SuriRaja's picture
# Update app.py
# 1b13842 verified
import gradio as gr
import logging
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
import faiss
from simple_salesforce import Salesforce
from dotenv import load_dotenv
import zipfile
from pathlib import Path
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from a local .env file (if present) so the
# credentials below can be read with os.getenv.
load_dotenv()

# Salesforce credentials, read from the environment.
sf_username = os.getenv("SF_USERNAME")
sf_password = os.getenv("SF_PASSWORD")
sf_security_token = os.getenv("SF_SECURITY_TOKEN")
sf_instance_url = os.getenv("SF_INSTANCE_URL")

# Fail fast when any credential is missing or empty. Only the variable
# *names* are reported -- never the values, which are secrets.
_missing_sf_vars = [
    name
    for name, value in (
        ("SF_USERNAME", sf_username),
        ("SF_PASSWORD", sf_password),
        ("SF_SECURITY_TOKEN", sf_security_token),
        ("SF_INSTANCE_URL", sf_instance_url),
    )
    if not value
]
if _missing_sf_vars:
    logger.error(
        "❌ Salesforce credentials are missing from environment variables: %s",
        ", ".join(_missing_sf_vars),
    )
    raise ValueError(
        "Salesforce credentials are not properly set. Missing: "
        + ", ".join(_missing_sf_vars)
    )
# Salesforce connection
# Build the module-level Salesforce client; create_salesforce_record() uses
# `sf`, so the app refuses to start if the login fails.
try:
    sf = Salesforce(
        username=sf_username,
        password=sf_password,
        security_token=sf_security_token,
        instance_url=sf_instance_url,
    )
except Exception as e:
    # logger.exception records the traceback as well as the message.
    logger.exception(f"❌ Salesforce connection failed: {str(e)}")
    raise
else:
    logger.info("✅ Connected to Salesforce")
# --- Extract zip files and read documents ---
def extract_zip(zip_path, extract_to):
    """Unpack every member of the archive at *zip_path* into *extract_to*.

    Success is logged at INFO level. Any exception (missing file, corrupt
    archive, I/O error) is logged at ERROR level and then re-raised unchanged.
    """
    try:
        with zipfile.ZipFile(zip_path, mode="r") as archive:
            archive.extractall(path=extract_to)
        logger.info(f"Extracted {zip_path} to {extract_to}")
    except Exception as exc:
        logger.error(f"Failed to extract {zip_path}: {exc}")
        raise
def load_documents(folder_path, pattern="*.txt"):
    """Read every file matching *pattern* under *folder_path*, recursively.

    Args:
        folder_path: Directory to search (str or Path).
        pattern: Glob pattern matched at every depth; defaults to ``"*.txt"``.

    Returns:
        ``(documents, sources)``: two parallel lists holding each file's text
        and its bare file name (no directory part).

    Files are visited in sorted path order, so downstream chunk order is
    reproducible across runs and platforms. Undecodable bytes are dropped
    (``errors="ignore"``) rather than aborting the load.
    """
    documents = []
    sources = []
    for file in sorted(Path(folder_path).rglob(pattern)):
        # rglob also yields directories whose names match the pattern
        # (e.g. "notes.txt/"); read_text() would raise on those.
        if not file.is_file():
            continue
        documents.append(file.read_text(encoding="utf-8", errors="ignore"))
        sources.append(file.name)
    return documents, sources
# --- Chunking ---
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=300,    # upper bound on chunk length (characters by default)
    chunk_overlap=50,  # overlap shared between neighbouring chunks
)
# --- Load model ---
# One embedding model, used both for the document chunks and for user queries.
model = SentenceTransformer("all-MiniLM-L6-v2")
# --- Preprocessing ---
data_dir = Path("./data")
data_dir.mkdir(exist_ok=True)

# (archive expected in the working directory, sub-folder of ./data to unpack it into)
doc_folders = [
    ("Company_Policies.zip", "Company_Policies"),
    ("HR_Policies.zip", "Hr_Policies"),
    ("Contract_Clauses.zip", "Contract_Clauses"),
]

# Parallel lists: metadata[i] is the source URL of the text in all_chunks[i].
all_chunks = []
metadata = []

for archive_name, folder in doc_folders:
    archive = Path(archive_name)
    if not archive.exists():
        logger.error(f"Zip file {archive_name} not found")
        raise FileNotFoundError(f"Zip file {archive_name} not found")

    target_dir = data_dir / folder
    target_dir.mkdir(exist_ok=True)
    extract_zip(archive, target_dir)

    texts, names = load_documents(target_dir)
    if not texts:
        logger.error(f"No documents found in {target_dir}")
        raise ValueError(f"No documents found in {target_dir}")

    for text, name in zip(texts, names):
        pieces = text_splitter.split_text(text)
        all_chunks += pieces
        # Every chunk of one document shares that document's URL.
        metadata += [f"https://company.com/{folder}/{name}"] * len(pieces)
# --- Embeddings + FAISS index ---
# Fail with a clear message instead of an opaque IndexError from
# embeddings.shape[1] when every document produced zero chunks
# (e.g. all source files are empty).
if not all_chunks:
    logger.error("No text chunks were produced from the documents")
    raise ValueError(
        "No text chunks were produced from the documents; cannot build the FAISS index"
    )
# FAISS's Python API expects a C-contiguous float32 matrix, one row per chunk.
embeddings = np.ascontiguousarray(model.encode(all_chunks), dtype=np.float32)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
logger.info("FAISS index built successfully")
# --- Create Record in Salesforce ---
def create_salesforce_record(query, answer, confidence_percentage, source_link):
    """Log one question/answer exchange as a ``chat_query_log__c`` record.

    Args:
        query: The user's question (stored in ``Query__c``).
        answer: The answer shown to the user (``Answer__c``).
        confidence_percentage: Numeric confidence (``Confidence_Percentage__c``).
            It is coerced to a plain Python float because numpy scalars such as
            float32 are not JSON-serialisable.
        source_link: URL of the source document (``Document_link__c``).

    Returns:
        The new record's ID when the create response contains an ``'id'``,
        otherwise ``None``. Errors are logged and never propagated, so a
        Salesforce failure cannot break the chat response.
    """
    try:
        data = {
            "Query__c": query,
            "Answer__c": answer,
            "Confidence_Percentage__c": float(confidence_percentage),
            "Document_link__c": source_link,
        }
        response = sf.chat_query_log__c.create(data)
        # A successful create returns the new record's 'id'.
        if 'id' in response:
            record_id = response['id']
            logger.info(f"✅ Record created successfully in Salesforce with ID: {record_id}")
            return record_id
        logger.error(f"❌ Failed to create Salesforce record. Response: {response}")
        return None
    except Exception as e:
        # Best-effort logging: keep the traceback in the log, but don't raise.
        logger.exception(f"Error creating Salesforce record: {str(e)}")
        return None
# --- Search & Answer ---
def answer_query(query, k=3, max_distance=0.8):
    """Answer *query* with the closest indexed chunk and log it to Salesforce.

    Args:
        query: The user's question.
        k: How many nearest chunks to request from the FAISS index.
        max_distance: A hit counts as relevant only if its FAISS distance is
            strictly below this value.

    Returns:
        Always a 4-tuple of display strings
        ``(answer, confidence, source_link, salesforce_status)``, so callers
        can unpack it unconditionally. This includes the "no relevant
        information" and error paths.
    """
    try:
        logger.info(f"Processing query: {query}")
        query_embedding = model.encode([query])
        D, I = index.search(np.array(query_embedding), k=k)

        # FAISS labels missing results with -1 (e.g. when the index holds
        # fewer than k vectors). Without the idx >= 0 check, all_chunks[-1]
        # would silently alias the last chunk.
        hits = [
            (dist, idx)
            for dist, idx in zip(D[0], I[0])
            if idx >= 0 and dist < max_distance
        ]
        if not hits:
            return (
                "No relevant information found.",
                "Confidence: 0%",
                "Source Link: None",
                "No Salesforce record created",
            )

        # The closest relevant hit supplies the answer, its source and the score.
        min_distance, best = min(hits, key=lambda hit: hit[0])
        answer = all_chunks[best].strip()
        source_link = metadata[best]
        # Distance 0 -> 100%; each 0.01 of distance costs one point, floored at 0.
        confidence_percentage = max(0.0, 100.0 - float(min_distance) * 100.0)

        record_id = create_salesforce_record(query, answer, confidence_percentage, source_link)
        salesforce_status = (
            f"Salesforce Record ID: {record_id}"
            if record_id
            else "Failed to create record in Salesforce"
        )
        return (
            answer,
            f"Confidence: {confidence_percentage:.2f}%",
            f"Source Link: {source_link}",
            salesforce_status,
        )
    except Exception as e:
        logger.exception(f"Error in answer_query: {str(e)}")
        return f"Error: {str(e)}", "", "", ""
# --- Gradio Chatbot UI Design ---
def process_question(q, chat_history):
    """Gradio handler: answer *q* and append the exchange to the chat.

    Args:
        q: Text from the question box (may be empty or ``None``).
        chat_history: Current value of the ``gr.Chatbot`` component in its
            tuple format, i.e. a list of ``(user_message, bot_message)`` pairs.
            May be ``None`` before the first exchange.

    Returns:
        A 4-tuple ``(history, confidence, source, record_status)`` matching the
        four output components this handler is wired to.
    """
    # Copy so the incoming list is never mutated; treat None as empty.
    history = list(chat_history or [])
    if not q or not q.strip():
        # Bot-only message: None leaves the user side of the pair blank.
        history.append((None, "Please enter a question."))
        return history, "", "", ""
    answer, confidence, source, record_id = answer_query(q)
    # Tuple-format Chatbot renders pair[0] as the user bubble and pair[1] as
    # the bot bubble, so one exchange is exactly one (question, answer) pair.
    history.append((q, answer))
    return history, confidence, source, record_id
# --- Chatbot UI with dynamic styling using elem_id ---
# Custom page styles, passed to gr.Blocks(css=...) below, which is the
# documented way to attach CSS (instead of patching demo.css afterwards).
# NOTE(review): the .gradio-chatbot-message-* class names may not match the
# markup Gradio actually generates; verify them in the browser.
_CHATBOT_CSS = """
/* Chatbot Container */
#chatbox {
background-color: #f9f9f9;
border-radius: 12px;
box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
padding: 15px;
overflow-y: auto;
}
/* User and Bot message bubbles */
.gradio-chatbot-message-user {
background-color: #0d6efd;
color: white;
border-radius: 15px;
padding: 10px;
margin: 5px 0;
animation: fadeIn 0.5s ease-in-out;
}
.gradio-chatbot-message-bot {
background-color: #f1f1f1;
color: #333;
border-radius: 15px;
padding: 10px;
margin: 5px 0;
animation: fadeIn 0.5s ease-in-out;
}
/* Input Box */
#user-question {
background-color: #e9ecef;
border-radius: 8px;
padding: 10px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
transition: background-color 0.3s ease;
}
#user-question:hover {
background-color: #f1f1f1;
}
/* Submit Button */
#submit-btn {
background-color: #007bff;
color: white;
border-radius: 8px;
transition: transform 0.2s ease-in-out;
margin-top: 15px;
}
#submit-btn:hover {
transform: scale(1.1);
}
/* Animation for message appearance */
@keyframes fadeIn {
0% { opacity: 0; transform: translateY(20px); }
100% { opacity: 1; transform: translateY(0); }
}
"""

with gr.Blocks(
    title="Company Documents Q&A Chatbot",
    theme=gr.themes.Soft(),
    css=_CHATBOT_CSS,
) as demo:
    gr.Markdown("## 📚 **Company Policies Q&A Chatbot**")
    with gr.Row():
        with gr.Column(scale=3):
            question = gr.Textbox(
                label="Ask a Question",
                placeholder="What are the conditions for permanent employment status?",
                lines=1,
                interactive=True,
                elem_id="user-question",
            )
        with gr.Column(scale=1):
            submit_btn = gr.Button("Submit", variant="primary", elem_id="submit-btn", scale=2)
    with gr.Row():
        with gr.Column():
            chat_history = gr.Chatbot(
                label="Chat History",
                elem_id="chatbox",
                height=400,  # fixed height; the CSS makes the box scroll
                show_label=False,
            )
            conf_out = gr.Markdown(label="Confidence", elem_id="confidence")
            source_out = gr.Markdown(label="Source Link", elem_id="source-link")
            record_out = gr.Markdown(label="Salesforce Record ID", elem_id="salesforce-id")

    # process_question returns one value per component, in this order.
    _handler_inputs = [question, chat_history]
    _handler_outputs = [chat_history, conf_out, source_out, record_out]
    submit_btn.click(fn=process_question, inputs=_handler_inputs, outputs=_handler_outputs)
    # Pressing Enter in the single-line question box runs the same handler.
    question.submit(fn=process_question, inputs=_handler_inputs, outputs=_handler_outputs)

# NOTE(review): share=True publishes a public gradio.live link, and every
# answered query through it creates a Salesforce record. Confirm this is intended.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)