# CXRAG / app.py
# Last commit by ghostai1: "Update app.py" (fd58b74, verified)
import gradio as gr
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import matplotlib.pyplot as plt
import seaborn as sns
import time
import io
import re
import os
# Embedded call-center FAQ knowledge base, parsed with pd.read_csv below.
# Double quotes inside a quoted field are escaped by doubling them ("").
# The sample appears intentionally noisy (presumably to exercise the cleanup
# step): it has a row with an empty question and answer, a repeated question,
# a very short "help" entry, an "Invalid query!!!" / "N/A" entry, and a row
# whose call_id is blank.
csv_data = """question,answer,call_id,agent_id,timestamp,language
"How do I reset my password?","Go to the login page, click ""Forgot Password,"" and follow the email instructions.",12345,A001,2025-04-01 10:15:23,en
"What are your pricing plans?","We offer Basic ($10/month), Pro ($50/month), and Enterprise (custom).",12346,A002,2025-04-01 10:17:45,en
"How do I contact support?","Email support@partner.com or call +1-800-123-4567.",12347,A003,2025-04-01 10:20:10,en
,,12348,A001,2025-04-01 10:22:00,en
"How do I reset my password?","Duplicate answer.",12349,A002,2025-04-01 10:25:30,en
"help","Contact us.",12350,A004,2025-04-01 10:27:15,en
"What is the refund policy?","Refunds available within 30 days; contact support.",12351,A005,2025-04-01 10:30:00,es
"Invalid query!!!","N/A",12352,A006,2025-04-01 10:32:45,en
"How do I update my billing?","Log in, go to ""Billing,"" and update your payment method.",,A007,2025-04-01 10:35:10,en
"What are pricing plans?","Basic ($10/month), Pro ($50/month).",12353,A002,2025-04-01 10:37:20,en"""
# Data cleanup function
def clean_faqs(df, *, output_path='cleaned_call_center_faqs.csv',
               min_question_len=10, min_answer_len=20):
    """Filter junk rows out of an FAQ table and normalize a few fields.

    Rows are removed in this order, and each step's count is recorded:
    null question/answer, repeated question (first occurrence kept),
    too-short question/answer, and malformed questions (runs of two or more
    ``!``/``?`` characters, or the words "Invalid" / "N/A").

    Args:
        df: DataFrame with at least ``question``, ``answer`` and ``language``
            columns. It is not modified.
        output_path: Where the cleaned table is written as CSV (without the
            index). Pass ``None`` to skip writing the file.
        min_question_len: Questions shorter than this many characters are
            dropped.
        min_answer_len: Answers shorter than this many characters are dropped.

    Returns:
        ``(cleaned_df, cleanup_details)``. ``cleanup_details`` maps
        'original', 'nulls_removed', 'duplicates_removed', 'short_removed',
        'malformed_removed', 'cleaned' and 'removed' to row counts (ints).
    """
    original_count = len(df)
    cleanup_details = {
        'original': original_count,
        'nulls_removed': 0,
        'duplicates_removed': 0,
        'short_removed': 0,
        'malformed_removed': 0,
    }

    # Remove nulls. Take an explicit copy so the column assignments below
    # never write through to the caller's frame (and don't trigger
    # SettingWithCopyWarning on a filtered slice).
    null_rows = df['question'].isna() | df['answer'].isna()
    cleanup_details['nulls_removed'] = int(null_rows.sum())
    df = df.loc[~null_rows].copy()
    # Force text: an all-null (float64) or numeric column would otherwise
    # break the ``.str`` accessor used below.
    df['question'] = df['question'].astype(str)
    df['answer'] = df['answer'].astype(str)

    # Remove duplicates (the first occurrence of each question is kept).
    duplicate_rows = df['question'].duplicated()
    cleanup_details['duplicates_removed'] = int(duplicate_rows.sum())
    df = df.loc[~duplicate_rows]

    # Remove short entries
    short_rows = ((df['question'].str.len() < min_question_len)
                  | (df['answer'].str.len() < min_answer_len))
    cleanup_details['short_removed'] = int(short_rows.sum())
    df = df.loc[~short_rows]

    # Remove malformed questions. The group is non-capturing so pandas does
    # not warn about match groups in ``str.contains``.
    malformed_rows = df['question'].str.contains(
        r'[!?]{2,}|\b(?:Invalid|N/A)\b', regex=True, case=False, na=False)
    cleanup_details['malformed_removed'] = int(malformed_rows.sum())
    # Copy again before the in-place column updates below.
    df = df.loc[~malformed_rows].copy()

    # Standardize text: expand the standalone abbreviation "mo" to "month"
    # in answers and default missing languages to English.
    df['answer'] = df['answer'].str.replace(r'\bmo\b', 'month', regex=True, case=False)
    df['language'] = df['language'].fillna('en')

    cleaned_count = len(df)
    cleanup_details['cleaned'] = cleaned_count
    cleanup_details['removed'] = original_count - cleaned_count

    # Save cleaned CSV for modeling
    if output_path is not None:
        df.to_csv(output_path, index=False)
    return df, cleanup_details
# Load and clean FAQs. Any failure here aborts module import, since the app
# has no knowledge base without it; the original error is chained so its
# traceback is preserved.
try:
    faq_data = pd.read_csv(io.StringIO(csv_data), quotechar='"', escapechar='\\')
    faq_data, cleanup_details = clean_faqs(faq_data)
except Exception as e:
    raise RuntimeError(f"Failed to load/clean FAQs: {e}") from e
# Initialize RAG components: embed every cleaned FAQ question once and load
# the vectors into a FAISS IndexFlatL2. Vectors are added in faq_data's row
# order, so search result i maps to faq_data.iloc[i].
try:
    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = embedder.encode(faq_data['question'].tolist(), show_progress_bar=False)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    # FAISS expects a C-contiguous float32 matrix.
    index.add(np.ascontiguousarray(embeddings, dtype=np.float32))
except Exception as e:
    raise RuntimeError(f"Failed to initialize RAG components: {e}") from e
# RAG process
def _zero_metrics():
    """Return a metrics dict with the keys rag_process reports, all zero."""
    return {
        'embed_time': 0.0,
        'retrieval_time': 0.0,
        'generation_time': 0.0,
        'accuracy': 0.0,
    }


def rag_process(query, k=2):
    """Answer *query* by retrieving the closest stored FAQ questions.

    The query is embedded with the module-level ``embedder``, the ``k``
    nearest FAQ questions are looked up in the FAISS ``index``, and the
    answer of the closest one is returned verbatim. There is no generative
    model: the "generation" stage only picks that answer.

    Args:
        query: User question. ``None``, or text shorter than 5 characters
            after stripping whitespace, is rejected.
        k: Number of FAQs to retrieve. It is clamped to the range
            [1, number of indexed FAQs].

    Returns:
        Always a 3-tuple ``(response, retrieved_faqs, metrics)``.
        ``retrieved_faqs`` is a list of ``{'question': ..., 'answer': ...}``
        dicts. ``metrics`` has 'embed_time', 'retrieval_time' and
        'generation_time' in milliseconds, plus 'accuracy'. On invalid
        input or an embedding error, ``response`` is the error message,
        ``retrieved_faqs`` is empty and every metric is 0.
    """
    text = (query or "").strip()
    if len(text) < 5:
        return "Invalid query. Please select a question.", [], _zero_metrics()

    start_time = time.perf_counter()
    try:
        query_embedding = embedder.encode([text], show_progress_bar=False)
    except Exception as e:
        return f"Error embedding query: {str(e)}", [], _zero_metrics()
    embed_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    k = min(max(int(k), 1), index.ntotal)
    retrieved_faqs = []
    if k > 0:
        _, indices = index.search(np.ascontiguousarray(query_embedding, dtype=np.float32), k)
        # FAISS pads missing results with -1. Drop those so iloc[-1] cannot
        # silently return the last FAQ.
        hits = [i for i in indices[0] if i >= 0]
        retrieved_faqs = faq_data.iloc[hits][['question', 'answer']].to_dict('records')
    retrieval_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    response = retrieved_faqs[0]['answer'] if retrieved_faqs else "Sorry, I couldn't find an answer."
    generation_time = time.perf_counter() - start_time

    metrics = {
        'embed_time': embed_time * 1000,
        'retrieval_time': retrieval_time * 1000,
        'generation_time': generation_time * 1000,
        # Placeholder, not a measured score: fixed at 95.0 whenever anything
        # was retrieved.
        'accuracy': 95.0 if retrieved_faqs else 0.0,
    }
    return response, retrieved_faqs, metrics
# Plot RAG pipeline
def plot_metrics(metrics, output_path='rag_plot.png'):
    """Render per-stage latency (bars) and accuracy (line) to a PNG file.

    Args:
        metrics: Dict with 'embed_time', 'retrieval_time', 'generation_time'
            (milliseconds) and 'accuracy' (percent), in the shape rag_process
            returns.
        output_path: Destination PNG path. It is overwritten on every call.

    Returns:
        ``output_path``.
    """
    data = pd.DataFrame({
        'Stage': ['Embedding', 'Retrieval', 'Generation'],
        'Latency (ms)': [metrics['embed_time'], metrics['retrieval_time'], metrics['generation_time']],
        # The embedding stage has no accuracy of its own; it is drawn as 100.
        'Accuracy (%)': [100, metrics['accuracy'], metrics['accuracy']],
    })
    sns.set_style("whitegrid")
    sns.set_palette("muted")
    fig, ax1 = plt.subplots(figsize=(10, 6))  # Large enough to stay readable
    try:
        sns.barplot(x='Stage', y='Latency (ms)', data=data, color='skyblue', ax=ax1)
        ax1.set_ylabel('Latency (ms)', color='steelblue')
        ax1.tick_params(axis='y', labelcolor='steelblue')

        # Accuracy shares the x axis but gets its own percent scale on the right.
        ax2 = ax1.twinx()
        sns.lineplot(x='Stage', y='Accuracy (%)', data=data, marker='o',
                     color='darkorange', linewidth=2, ax=ax2)
        ax2.set_ylim(0, 105)
        ax2.set_ylabel('Accuracy (%)', color='darkorange')
        ax2.tick_params(axis='y', labelcolor='darkorange')

        ax1.set_title('RAG Pipeline: Latency and Accuracy')
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if plotting fails; pyplot otherwise
        # keeps every open figure alive.
        plt.close(fig)
    return output_path
# Gradio click handler: one question in, the four UI outputs out
def chat_interface(query):
    """Run one question through the RAG pipeline and format the UI outputs.

    Returns a 4-tuple matching the Gradio outputs: the bot response, the
    retrieved FAQs as "Q: ...\\nA: ..." lines, a one-line data-cleanup
    summary, and the path of the metrics plot. Any exception is reported as
    an "Error: ..." response with the other three outputs blank.
    """
    try:
        answer, faqs, stage_metrics = rag_process(query)
        figure_file = plot_metrics(stage_metrics)
        retrieved_block = "\n".join(
            "Q: {}\nA: {}".format(item['question'], item['answer']) for item in faqs
        )
        stats_line = (
            "Cleaned FAQs: {cleaned} (removed {removed} junk entries: "
            "{nulls_removed} nulls, {duplicates_removed} duplicates, "
            "{short_removed} short, {malformed_removed} malformed)"
        ).format(**cleanup_details)
    except Exception as err:
        return f"Error: {str(err)}", "", "", None
    return answer, retrieved_block, stats_line, figure_file
# Dark-theme stylesheet passed to gr.Blocks(css=...) below. The #app-container,
# #app-row, #output-container and #button-container rules target the elem_id
# values given to the layout components, and .text-center targets the
# headings' elem_classes.
# NOTE(review): the .gr-box / .gr-button / .gr-textbox / .gr-image selectors
# depend on Gradio's internal class names, which may differ between Gradio
# releases — confirm they still match the installed version.
custom_css = """
body {
background: linear-gradient(135deg, #1a1a1a 0%, #2a2a2a 100%);
color: #e0e0e0;
font-family: 'Arial', sans-serif;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
}
.gr-box {
background: #3a3a3a;
border: 1px solid #4a4a4a;
border-radius: 8px;
padding: 20px; /* Increased padding for better spacing */
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.3);
}
.gr-button {
background: #1e90ff;
color: white;
border-radius: 5px;
padding: 12px 20px; /* Slightly larger padding for buttons */
margin: 8px 0; /* Increased margin for better spacing */
width: 100%;
text-align: center;
transition: background 0.3s ease;
font-size: 16px;
}
.gr-button:hover {
background: #1c86ee;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
}
.gr-textbox {
background: #2f2f2f;
color: #e0e0e0;
border: 1px solid #4a4a4a;
border-radius: 5px;
margin-bottom: 15px; /* Increased margin for better spacing */
font-size: 16px; /* Larger font size for readability */
padding: 15px; /* Increased padding for larger textboxes */
min-height: 120px; /* Increased height for better readability */
width: 100%; /* Ensure full width */
}
.gr-image {
width: 100%; /* Ensure the plot takes full width of container */
height: auto; /* Maintain aspect ratio */
max-height: 400px; /* Increased max height for larger plot */
}
#app-container {
max-width: 900px; /* Slightly wider container for better balance */
width: 100%;
padding: 20px;
background: #252525;
border-radius: 12px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);
}
#button-container {
display: flex;
flex-direction: column;
gap: 15px; /* Increased gap for better spacing */
padding: 20px; /* Increased padding for better alignment */
background: #303030;
border-radius: 8px;
align-items: center;
width: 100%; /* Full width within parent column */
}
#output-container {
background: #303030;
padding: 20px; /* Increased padding for larger output fields */
border-radius: 8px;
width: 100%; /* Full width within parent column */
}
.text-center {
text-align: center;
margin-bottom: 20px;
}
#app-row {
display: flex;
gap: 30px; /* Increased gap for better separation */
justify-content: space-between;
align-items: stretch; /* Ensure columns stretch to same height */
}
"""
# Build the UI. clean_faqs already dropped repeated questions, so each
# remaining question gets exactly one button.
unique_questions = faq_data['question'].tolist()
with gr.Blocks(css=custom_css) as demo:
    with gr.Column(elem_id="app-container"):
        gr.Markdown("# Customer Experience Bot Demo", elem_classes="text-center")
        gr.Markdown("Select a question to see the bot's response, retrieved FAQs, and call center data cleanup stats.", elem_classes="text-center")
        # Layout: outputs on left, buttons on right
        with gr.Row(elem_id="app-row"):
            # Single output panel (left 2/3)
            with gr.Column(elem_id="output-container", scale=2):
                response_output = gr.Textbox(label="Bot Response", elem_id="response-output")
                faq_output = gr.Textbox(label="Retrieved FAQs", elem_id="faq-output")
                cleanup_output = gr.Textbox(label="Data Cleanup Stats", elem_id="cleanup-output")
                plot_output = gr.Image(label="RAG Pipeline Metrics", elem_id="plot-output")
            # Stacked buttons (right 1/3)
            with gr.Column(elem_id="button-container", scale=1):
                for question in unique_questions:
                    # Each button gets its own gr.State holding its question,
                    # so the click handler receives that fixed text as input.
                    gr.Button(question).click(
                        fn=chat_interface,
                        inputs=gr.State(value=question),
                        outputs=[
                            response_output,
                            faq_output,
                            cleanup_output,
                            plot_output,
                        ],
                    )

# Start the server only when run as a script (`python app.py`), so importing
# this module does not launch it.
if __name__ == "__main__":
    demo.launch()