|
import gradio as gr |
|
import logging |
|
import os |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
import faiss |
|
from simple_salesforce import Salesforce |
|
from dotenv import load_dotenv |
|
import zipfile |
|
from pathlib import Path |
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from a local .env file (if present) into os.environ before
# the Salesforce credentials are read below.
load_dotenv()

sf_username = os.getenv("SF_USERNAME")
sf_password = os.getenv("SF_PASSWORD")
sf_security_token = os.getenv("SF_SECURITY_TOKEN")
sf_instance_url = os.getenv("SF_INSTANCE_URL")

# Every answered query is written to Salesforce, so stop at start-up when any
# credential is absent. Only the variable NAMES are reported, never values.
_missing_sf_vars = [
    name
    for name, value in (
        ("SF_USERNAME", sf_username),
        ("SF_PASSWORD", sf_password),
        ("SF_SECURITY_TOKEN", sf_security_token),
        ("SF_INSTANCE_URL", sf_instance_url),
    )
    if not value
]
if _missing_sf_vars:
    logger.error(
        "Salesforce credentials are missing from environment variables: %s",
        ", ".join(_missing_sf_vars),
    )
    raise ValueError(
        "Salesforce credentials are not properly set "
        f"(missing: {', '.join(_missing_sf_vars)})."
    )
|
|
|
|
|
# Open a single Salesforce session at import time; create_salesforce_record
# uses this `sf` object. Any failure here is logged and aborts start-up.
try:
    sf = Salesforce(
        username=sf_username,
        password=sf_password,
        security_token=sf_security_token,
        # NOTE(review): with username/password/token auth, simple_salesforce
        # appears to take the instance from the login response, so this may
        # be ignored -- confirm against the installed version.
        instance_url=sf_instance_url,
    )
except Exception as e:
    # logger.exception keeps the traceback that the str(e)-only message lost.
    logger.exception("Salesforce connection failed: %s", e)
    raise
logger.info("Connected to Salesforce")
|
|
|
|
|
def extract_zip(zip_path, extract_to):
    """Unpack every member of the archive at *zip_path* into *extract_to*.

    Success is logged at INFO level. On failure the error is logged and the
    original exception is re-raised unchanged for the caller to handle.
    """
    try:
        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(extract_to)
    except Exception as exc:
        logger.error(f"Failed to extract {zip_path}: {str(exc)}")
        raise
    else:
        logger.info(f"Extracted {zip_path} to {extract_to}")
|
|
|
def load_documents(folder_path):
    """Read every ``.txt`` file found (recursively) under *folder_path*.

    Returns:
        A pair of parallel lists ``(documents, sources)``: the decoded text of
        each file and that file's bare name (no directory part).

    Files are visited in a stable, sorted order so repeated runs produce the
    same document order. Entries that macOS archivers add to zip files --
    anything under a ``__MACOSX`` directory and ``._*`` AppleDouble files --
    are skipped: they can match ``*.txt`` but hold binary metadata, not text.
    Directories whose names end in ``.txt`` are skipped too. Undecodable
    bytes are dropped (``errors="ignore"``).
    """
    root = Path(folder_path)
    documents = []
    sources = []
    for file in sorted(root.rglob("*.txt"), key=lambda p: p.as_posix()):
        if "__MACOSX" in file.relative_to(root).parts or file.name.startswith("._"):
            continue
        if not file.is_file():
            continue
        documents.append(file.read_text(encoding="utf-8", errors="ignore"))
        sources.append(file.name)
    return documents, sources
|
|
|
|
|
# Chunking parameters: pieces of up to ~300 characters, with 50 characters
# shared between neighbouring chunks so context is not lost at boundaries.
CHUNK_SIZE = 300
CHUNK_OVERLAP = 50
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

# One embedding model is used both for the document chunks and for queries.
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Archives are unpacked under ./data/<folder>; make sure the root exists.
data_dir = Path(".") / "data"
data_dir.mkdir(exist_ok=True)

# (zip archive, resolved relative to the working directory;
#  sub-folder of data_dir it is extracted into -- also used in source URLs)
# NOTE(review): "Hr_Policies" differs in case from "HR_Policies.zip"; confirm
# the folder/URL spelling is intended.
doc_folders = [
    ("Company_Policies.zip", "Company_Policies"),
    ("HR_Policies.zip", "Hr_Policies"),
    ("Contract_Clauses.zip", "Contract_Clauses"),
]
|
|
|
# Parallel lists: all_chunks[i] is a text chunk and metadata[i] is the URL of
# the file it came from; answer_query looks both up with the same FAISS row id.
all_chunks = []
metadata = []

for archive_name, folder_name in doc_folders:
    archive = Path(archive_name)
    if not archive.exists():
        missing_msg = f"Zip file {archive_name} not found"
        logger.error(missing_msg)
        raise FileNotFoundError(missing_msg)

    target_dir = data_dir / folder_name
    target_dir.mkdir(exist_ok=True)
    # NOTE(review): files left over from an earlier, larger archive are not
    # removed before re-extracting -- confirm that is acceptable.
    extract_zip(archive, target_dir)

    texts, names = load_documents(target_dir)
    if not texts:
        empty_msg = f"No documents found in {target_dir}"
        logger.error(empty_msg)
        raise ValueError(empty_msg)

    for text, name in zip(texts, names):
        # Every chunk of one file points back to the same source URL.
        doc_url = f"https://company.com/{folder_name}/{name}"
        pieces = text_splitter.split_text(text)
        all_chunks += pieces
        metadata += [doc_url] * len(pieces)
|
|
|
|
|
# Embed every chunk and load the vectors into an exact (flat) L2 FAISS index.
# FAISS row i corresponds to all_chunks[i] and metadata[i].
if not all_chunks:
    # Files can exist yet contain no text (e.g. all empty). Without this
    # guard the failure surfaces later as an unhelpful error (likely an
    # IndexError on embeddings.shape[1]).
    logger.error("No text chunks were produced from the documents")
    raise ValueError("No text chunks were produced from the documents")

# FAISS works on float32 matrices; asarray is a no-op if already float32.
embeddings = np.asarray(model.encode(all_chunks), dtype=np.float32)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
logger.info("FAISS index built successfully")
|
|
|
|
|
def create_salesforce_record(query, answer, confidence_percentage, source_link):
    """Log one answered question as a ``chat_query_log__c`` record.

    Args:
        query: The user's question, stored in ``Query__c``.
        answer: The answer text shown to the user, stored in ``Answer__c``.
        confidence_percentage: Confidence score, coerced to a plain ``float``
            (callers may pass NumPy scalars) and stored in
            ``Confidence_Percentage__c``.
        source_link: Source document URL, stored in ``Document_link__c``.

    Returns:
        The new record's Salesforce ID, or ``None`` if it could not be
        created. Failures are logged rather than raised, so the caller can
        still show its answer when Salesforce is unavailable.
    """
    try:
        data = {
            "Query__c": query,
            "Answer__c": answer,
            "Confidence_Percentage__c": float(confidence_percentage),
            "Document_link__c": source_link,
        }
        response = sf.chat_query_log__c.create(data)

        # Defensive: treat a response without an "id" key as a failure.
        if "id" not in response:
            logger.error(f"Failed to create Salesforce record. Response: {response}")
            return None

        record_id = response["id"]
        logger.info(f"Record created successfully in Salesforce with ID: {record_id}")
        return record_id
    except Exception as e:
        # Best-effort logging: keep the traceback but never propagate.
        logger.exception(f"Error creating Salesforce record: {str(e)}")
        return None
|
|
|
|
|
def answer_query(query, k=3, max_distance=0.8):
    """Answer *query* with the closest indexed chunk and log it to Salesforce.

    The query is embedded with ``model`` and its ``k`` nearest chunks are
    fetched from ``index``. Hits whose distance is not below *max_distance*
    are discarded; the closest remaining chunk is returned verbatim.

    Args:
        query: The user's question.
        k: Number of nearest neighbours to retrieve (default 3).
        max_distance: Relevance cut-off (default 0.8). ``IndexFlatL2``
            reports squared L2 distances.

    Returns:
        Always a 4-tuple of display strings:
        ``(answer, "Confidence: ...", "Source Link: ...", salesforce_status)``.
        On an unexpected error the first item is ``"Error: ..."`` and the
        other three are empty strings.
    """
    try:
        logger.info(f"Processing query: {query}")
        query_embedding = np.asarray(model.encode([query]), dtype=np.float32)
        distances, ids = index.search(query_embedding, k=k)

        # FAISS pads with id -1 when the index holds fewer than k vectors;
        # skip those so -1 never selects the last chunk via negative indexing.
        hits = [
            (int(chunk_id), float(dist))
            for chunk_id, dist in zip(ids[0], distances[0])
            if chunk_id >= 0 and dist < max_distance
        ]

        if not hits:
            # Same 4-value shape as every other return path: process_question
            # unpacks exactly four values.
            return (
                "No relevant information found.",
                "Confidence: 0%",
                "Source Link: None",
                "",
            )

        best_id, best_distance = min(hits, key=lambda hit: hit[1])
        answer = all_chunks[best_id].strip()
        source_link = metadata[best_id]
        # Linear distance -> percentage mapping, clamped at 0
        # (distance 0 -> 100%; with the default cut-off, always above 20%).
        confidence_percentage = max(0.0, 100 - best_distance * 100)

        record_id = create_salesforce_record(query, answer, confidence_percentage, source_link)
        status = (
            f"Salesforce Record ID: {record_id}"
            if record_id
            else "Failed to create record in Salesforce"
        )
        return (
            answer,
            f"Confidence: {confidence_percentage:.2f}%",
            f"Source Link: {source_link}",
            status,
        )
    except Exception as e:
        # UI boundary: log with traceback and show the error instead of
        # crashing the Gradio handler.
        logger.exception(f"Error in answer_query: {str(e)}")
        return f"Error: {str(e)}", "", "", ""
|
|
|
|
|
def process_question(q, chat_history):
    """Gradio handler: answer *q* and append the exchange to the chat history.

    ``chat_history`` is a list of ``(user_message, bot_message)`` pairs; a
    ``None`` entry shows no bubble. NOTE(review): this assumes the Chatbot's
    default pair/tuple message format -- confirm for the installed Gradio.
    The incoming list is copied, not mutated, and ``None`` is treated as an
    empty history.

    Returns:
        A 4-tuple matching the handler's outputs:
        ``(chat_history, confidence_text, source_text, record_text)``.
    """
    history = list(chat_history or [])

    if not q or not q.strip():
        # Prompt as a bot message and keep the 4-output shape.
        history.append((None, "Please enter a question."))
        return history, "", "", ""

    answer, confidence, source, record_id = answer_query(q)
    history.append((q, answer))
    return history, confidence, source, record_id
|
|
|
|
|
# Custom styling for the chat UI. It is passed to gr.Blocks(css=...) -- the
# documented way to add CSS. Assigning demo.css after the Blocks context has
# closed may not be picked up.
# NOTE(review): ".gradio-chatbot-message-user/-bot" do not look like class
# names Gradio renders for chat bubbles; those two rules may have no effect.
CUSTOM_CSS = """
/* Chatbot Container */
#chatbox {
    background-color: #f9f9f9;
    border-radius: 12px;
    box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
    padding: 15px;
    overflow-y: auto;
}

/* User and Bot message bubbles */
.gradio-chatbot-message-user {
    background-color: #0d6efd;
    color: white;
    border-radius: 15px;
    padding: 10px;
    margin: 5px 0;
    animation: fadeIn 0.5s ease-in-out;
}

.gradio-chatbot-message-bot {
    background-color: #f1f1f1;
    color: #333;
    border-radius: 15px;
    padding: 10px;
    margin: 5px 0;
    animation: fadeIn 0.5s ease-in-out;
}

/* Input Box */
#user-question {
    background-color: #e9ecef;
    border-radius: 8px;
    padding: 10px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
    transition: background-color 0.3s ease;
}

#user-question:hover {
    background-color: #f1f1f1;
}

/* Submit Button */
#submit-btn {
    background-color: #007bff;
    color: white;
    border-radius: 8px;
    transition: transform 0.2s ease-in-out;
    margin-top: 15px;
}

#submit-btn:hover {
    transform: scale(1.1);
}

/* Animation for message appearance */
@keyframes fadeIn {
    0% { opacity: 0; transform: translateY(20px); }
    100% { opacity: 1; transform: translateY(0); }
}
"""

with gr.Blocks(
    title="Company Documents Q&A Chatbot",
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown("## **Company Policies Q&A Chatbot**")

    with gr.Row():
        with gr.Column(scale=3):
            question = gr.Textbox(
                label="Ask a Question",
                placeholder="What are the conditions for permanent employment status?",
                lines=1,
                interactive=True,
                elem_id="user-question",
            )
        with gr.Column(scale=1):
            submit_btn = gr.Button("Submit", variant="primary", elem_id="submit-btn", scale=2)

    with gr.Row():
        with gr.Column():
            chat_history = gr.Chatbot(
                label="Chat History",
                elem_id="chatbox",
                height=400,
                show_label=False,
            )
            conf_out = gr.Markdown(label="Confidence", elem_id="confidence")
            source_out = gr.Markdown(label="Source Link", elem_id="source-link")
            record_out = gr.Markdown(label="Salesforce Record ID", elem_id="salesforce-id")

    submit_btn.click(
        fn=process_question,
        inputs=[question, chat_history],
        outputs=[chat_history, conf_out, source_out, record_out],
    )
    # Pressing Enter in the question box behaves like clicking Submit.
    question.submit(
        fn=process_question,
        inputs=[question, chat_history],
        outputs=[chat_history, conf_out, source_out, record_out],
    )

# NOTE(review): share=True publishes the app via a public Gradio share link,
# and every answered question writes a record to Salesforce -- confirm public
# access is intended.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|