import os
import gradio as gr
from huggingface_hub import InferenceClient
import torch
import re
import warnings
import time
import json
from sentence_transformers import SentenceTransformer, util, CrossEncoder
import gspread
from ddgs import DDGS  # maintained successor to the duckduckgo_search package
import spacy
from datetime import datetime
from dateparser.search import search_dates
import pytz  # timezone handling (East Africa Time)
import traceback
import base64

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
# Define global variables and load secrets
HF_TOKEN = os.getenv("HF_TOKEN")
# Confirm whether HF_TOKEN was loaded (masked so the secret is never printed)
print(f"HF_TOKEN loaded: {'*' * len(HF_TOKEN) if HF_TOKEN else 'None'}")

SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"

# Base64-encoded Google service-account credentials, stored as a Space secret
GOOGLE_BASE64_CREDENTIALS = os.getenv("GOOGLE_BASE64_CREDENTIALS")

# Get the API key from Space Secrets; the name must match the one configured
# in the Hugging Face Space settings
SECRET_API_KEY = os.getenv("APP_API_KEY")
print(f"SECRET_API_KEY loaded: {'*' * len(SECRET_API_KEY) if SECRET_API_KEY else 'None'}")
if not SECRET_API_KEY:
    print("Warning: APP_API_KEY secret not set. API key validation will fail.")
elif not SECRET_API_KEY.startswith("fs_"):
    print("Warning: APP_API_KEY secret does not start with 'fs_'. Please check your secret.")
# Initialize the InferenceClient with the loaded HF_TOKEN
# (earlier experiments used other models):
# client = InferenceClient("google/gemma-2-9b-it", token=HF_TOKEN)
# client = InferenceClient("meta-llama/Llama-4-Scout-17B-16E-Instruct", token=HF_TOKEN)
client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", token=HF_TOKEN)
# Load spaCy model for sentence splitting
nlp = None
try:
    nlp = spacy.load("en_core_web_sm")
    print("SpaCy model 'en_core_web_sm' loaded.")
except OSError:
    print("SpaCy model 'en_core_web_sm' not found. Downloading...")
    try:
        os.system("python -m spacy download en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
        print("SpaCy model 'en_core_web_sm' downloaded and loaded.")
    except Exception as e:
        print(f"Failed to download or load SpaCy model: {e}")
# Load SentenceTransformer for RAG/business info retrieval and semantic detection
embedder = None
try:
    print("Attempting to load Sentence Transformer (sentence-transformers/paraphrase-MiniLM-L6-v2)...")
    embedder = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")  # or 'all-MiniLM-L6-v2' if preferred
    print("Sentence Transformer loaded.")
except Exception as e:
    print(f"Error loading Sentence Transformer: {e}")

# Load a Cross-Encoder model for re-ranking retrieved documents
reranker = None
try:
    print("Attempting to load Cross-Encoder Reranker (cross-encoder/ms-marco-MiniLM-L6-v2)...")
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
    print("Cross-Encoder Reranker loaded.")
except Exception as e:
    print(f"Error loading Cross-Encoder Reranker: {e}")
    print("Please ensure the model identifier 'cross-encoder/ms-marco-MiniLM-L6-v2' is correct and accessible on Hugging Face Hub.")
    print(traceback.format_exc())
    reranker = None
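# Note: the embedder and reranker form a two-stage retrieval pipeline: the
# bi-encoder cheaply shortlists candidate rows by cosine similarity, and the
# cross-encoder re-scores each (query, description) pair for better precision.
# retrieve_business_info() below implements this retrieve-then-rerank pattern.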
# Google Sheets Authentication
gc = None  # Global variable for the gspread client

def authenticate_google_sheets():
    """Authenticates with Google Sheets using base64-encoded service-account credentials."""
    global gc
    print("Authenticating Google Account...")
    if not GOOGLE_BASE64_CREDENTIALS:
        print("Error: GOOGLE_BASE64_CREDENTIALS secret not found.")
        return False
    try:
        # Decode the base64 credentials into the service-account JSON
        credentials_json = base64.b64decode(GOOGLE_BASE64_CREDENTIALS).decode('utf-8')
        credentials = json.loads(credentials_json)
        # Authenticate using the service account from a dictionary
        gc = gspread.service_account_from_dict(credentials)
        print("Google Sheets authentication successful via service account.")
        return True
    except Exception as e:
        print(f"Google Sheets authentication failed: {e}")
        print(traceback.format_exc())
        print("Please ensure your GOOGLE_BASE64_CREDENTIALS secret is correctly set and contains valid service account credentials.")
        return False
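# One way to produce the GOOGLE_BASE64_CREDENTIALS secret from a standard
# service-account JSON key (run once, locally; the file name is illustrative):
#
#   import base64
#   with open("service_account.json", "rb") as f:
#       print(base64.b64encode(f.read()).decode("utf-8"))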
# Google Sheets Data Loading and Embedding
data = []  # Global variable to store loaded rows
descriptions_for_embedding = []
embeddings = torch.tensor([])
business_info_available = False  # Flag set when business info loads successfully

def load_business_info():
    """Loads business information from the Google Sheet and creates embeddings."""
    global data, descriptions_for_embedding, embeddings, business_info_available
    business_info_available = False  # Reset flag
    if gc is None:
        print("Skipping Google Sheet loading: Google Sheets client not authenticated.")
        return
    if not SHEET_ID:
        print("Error: SHEET_ID not set.")
        return
    try:
        sheet = gc.open_by_key(SHEET_ID).sheet1
        print(f"Successfully opened Google Sheet with ID: {SHEET_ID}")
        data_records = sheet.get_all_records()
        if not data_records:
            print(f"Warning: No data records found in Google Sheet with ID: {SHEET_ID}")
            data = []
            descriptions_for_embedding = []
        else:
            # Filter out rows missing 'Service' or 'Description'
            filtered_data = [row for row in data_records if row.get('Service') and row.get('Description')]
            if not filtered_data:
                print("Warning: Filtered data is empty after checking for 'Service' and 'Description'.")
                data = []
                descriptions_for_embedding = []
            else:
                data = filtered_data
                # Use BOTH Service and Description for embedding
                descriptions_for_embedding = [f"Service: {row['Service']}. Description: {row['Description']}" for row in data]
        # Only encode if descriptions were found and the embedder is available
        if descriptions_for_embedding and embedder is not None:
            print("Encoding descriptions...")
            try:
                embeddings = embedder.encode(descriptions_for_embedding, convert_to_tensor=True)
                print("Encoding complete.")
                business_info_available = True
            except Exception as e:
                print(f"Error during description encoding: {e}")
                embeddings = torch.tensor([])
                business_info_available = False
        else:
            print("Skipping encoding descriptions: No descriptions found or embedder not available.")
            embeddings = torch.tensor([])
            business_info_available = False
        print(f"Loaded {len(descriptions_for_embedding)} entries from Google Sheet for embedding/RAG.")
        if not business_info_available:
            print("Business information retrieval (RAG) is NOT available.")
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Google Sheet with ID '{SHEET_ID}' not found.")
        print("Please check the SHEET_ID and ensure your authenticated Google Account has access to this sheet.")
        business_info_available = False
    except Exception as e:
        print(f"An error occurred while accessing the Google Sheet: {e}")
        print(traceback.format_exc())
        business_info_available = False
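# Assumed sheet layout (inferred from the filtering above): sheet1 needs header
# columns named 'Service' and 'Description', e.g. a row such as
#   Service: "Project Management" | Description: "End-to-end delivery planning..."
# Rows missing either column are dropped before embedding.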
# Business Info Retrieval (RAG)
def retrieve_business_info(query: str, top_n: int = 3) -> list:
    """
    Retrieves relevant business information from loaded data based on a query.

    Args:
        query: The user's query string.
        top_n: The number of top relevant entries to retrieve.

    Returns:
        A list of dictionaries, where each dictionary is a relevant row from the
        Google Sheet data. Returns an empty list if RAG is not available or
        no relevant information is found.
    """
    global data
    if not business_info_available or embedder is None or not descriptions_for_embedding or not data:
        print("Business information retrieval is not available or data is empty.")
        return []
    try:
        # Stage 1 (bi-encoder): embed the query and shortlist the top_n rows by cosine similarity
        query_embedding = embedder.encode(query, convert_to_tensor=True)
        cosine_scores = util.cos_sim(query_embedding, embeddings)[0]
        top_results_indices = torch.topk(cosine_scores, k=min(top_n, len(data)))[1].tolist()
        top_results = [data[i] for i in top_results_indices]
        # Stage 2 (cross-encoder): re-score each (query, description) pair and sort by score
        if reranker is not None and top_results:
            print("Re-ranking top results...")
            rerank_pairs = [(query, descriptions_for_embedding[i]) for i in top_results_indices]
            rerank_scores = reranker.predict(rerank_pairs)
            reranked_indices = sorted(range(len(rerank_scores)), key=lambda i: rerank_scores[i], reverse=True)
            reranked_results = [top_results[i] for i in reranked_indices]
            print("Re-ranking complete.")
            return reranked_results
        else:
            return top_results
    except Exception as e:
        print(f"Error during business information retrieval: {e}")
        print(traceback.format_exc())
        return []
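# Illustrative usage (the returned rows depend on the sheet contents):
#   rows = retrieve_business_info("Do you offer project management?")
#   # -> [{'Service': 'Project Management', 'Description': '...'}, ...]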
# Function to perform a DuckDuckGo search and return results with URLs
def perform_duckduckgo_search(query: str, max_results: int = 5):  # max_results kept small for multi-part queries
    """
    Performs a search using DuckDuckGo and returns a list of result dictionaries.
    Includes a short delay to avoid rate limits.
    Returns an empty list and prints an error if the search fails.
    """
    print(f"Executing Tool: perform_duckduckgo_search with query='{query}'")
    search_results_list = []
    try:
        time.sleep(1)
        with DDGS() as ddgs:
            search_query = query.strip()
            if not search_query or len(search_query.split()) < 2:
                print(f"Skipping search for short query: '{search_query}'")
                return []
            print(f"Sending search query to DuckDuckGo: '{search_query}'")
            results_generator = ddgs.text(search_query, max_results=max_results)
            results_found = False
            for r in results_generator:
                search_results_list.append(r)
                results_found = True
            print(f"Raw results from DuckDuckGo: {search_results_list}")
            if not results_found and max_results > 0:
                print(f"DuckDuckGo search for '{search_query}' returned no results.")
            elif results_found:
                print(f"DuckDuckGo search for '{search_query}' completed. Found {len(search_results_list)} results.")
    except Exception as e:
        print(f"Error during DuckDuckGo search for '{search_query if 'search_query' in locals() else query}': {e}")
        print(traceback.format_exc())
        return []
    return search_results_list
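# Each result yielded by ddgs.text() is a dict; the commonly seen keys are
# 'title', 'href', and 'body' (exact key names can vary between ddgs releases,
# which is why generate_text() below accepts both 'href' and 'url').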
# Semantic date/time detection and calculation using dateparser
def perform_date_calculation(query: str) -> str | None:
    """
    Analyzes the query for date/time information using dateparser.
    If dateparser finds a date, returns a human-friendly response string;
    otherwise returns None. Designed to handle Swahili and English queries
    and to report times for East Africa (Tanzania).
    """
    print(f"Executing Tool: perform_date_calculation with query='{query}' using dateparser.search_dates")
    try:
        eafrica_tz = pytz.timezone('Africa/Dar_es_Salaam')
        now = datetime.now(eafrica_tz)
    except pytz.UnknownTimeZoneError:
        print("Error: Unknown timezone 'Africa/Dar_es_Salaam'. Using default system time.")
        now = datetime.now()

    is_swahili = any(swahili_phrase in query.lower() for swahili_phrase in ['tarehe', 'siku', 'saa', 'muda', 'leo', 'kesho', 'jana', 'ngapi', 'gani', 'mwezi', 'mwaka', 'habari', 'mambo', 'shikamoo', 'karibu', 'asante'])

    # Handle common Swahili greetings before date parsing, so greetings routed
    # to this tool get a reply even when no date can be parsed from them
    if is_swahili:
        query_lower = query.lower().strip()
        if query_lower in ['habari', 'mambo', 'habari gani']:
            return "Nzuri! Habari zako?"  # Common Swahili response to greetings
        elif query_lower == 'shikamoo':
            return "Marahaba!"  # Response to Shikamoo
        elif query_lower == 'asante':
            return "Karibu!"  # Response to Asante
        elif query_lower == 'karibu':
            return "Asante!"  # Response to Karibu

    try:
        # Try parsing with Swahili first, then English
        found = search_dates(
            query,
            settings={
                "PREFER_DATES_FROM": "future",
                "RELATIVE_BASE": now
            },
            languages=['sw', 'en']  # Prioritize Swahili
        )
        if not found:
            print("dateparser.search_dates could not parse any date/time.")
            return None
        text_snippet, parsed = found[0]
        print(f"dateparser.search_dates found: text='{text_snippet}', parsed='{parsed}'")

        # Reconcile timezone awareness between 'now' and the parsed datetime
        if now.tzinfo is not None and parsed.tzinfo is None:
            parsed = now.tzinfo.localize(parsed)
        elif now.tzinfo is None and parsed.tzinfo is not None:
            parsed = parsed.replace(tzinfo=None)

        # Check whether the parsed date is today and the time is close to now or midnight
        if parsed.date() == now.date():
            # Treat it as "now" if within a minute of now, or if no specific time was parsed (midnight)
            if abs((parsed - now).total_seconds()) < 60 or parsed.time() == datetime.min.time():
                print("Query parsed to today's date and time is close to 'now' or midnight, returning current time/date.")
                if is_swahili:
                    return f"Kwa saa za Afrika Mashariki (Tanzania), tarehe ya leo ni {now.strftime('%A, %d %B %Y')} na saa ni {now.strftime('%H:%M:%S')}."
                else:
                    return f"In East Africa (Tanzania), the current date is {now.strftime('%A, %d %B %Y')} and the time is {now.strftime('%H:%M:%S')}."
            else:
                print(f"Query parsed to a specific time today: {parsed.strftime('%H:%M:%S')}")
                if is_swahili:
                    return f"Hiyo inafanyika leo, {parsed.strftime('%A, %d %B %Y')}, saa {parsed.strftime('%H:%M:%S')} saa za Afrika Mashariki."
                else:
                    return f"That falls on today, {parsed.strftime('%A, %d %B %Y')}, at {parsed.strftime('%H:%M:%S')} East Africa Time."
        else:
            print(f"Query parsed to a specific date: {parsed.strftime('%A, %d %B %Y')} at {parsed.strftime('%H:%M:%S')}")
            time_str = parsed.strftime('%H:%M:%S')
            date_str = parsed.strftime('%A, %d %B %Y')
            if parsed.tzinfo:
                tz_name = parsed.tzinfo.tzname(parsed) or 'UTC'
                if is_swahili:
                    return f"Hiyo inafanyika tarehe {date_str} saa {time_str} {tz_name}."
                else:
                    return f"That falls on {date_str} at {time_str} {tz_name}."
            else:
                if is_swahili:
                    return f"Hiyo inafanyika tarehe {date_str} saa {time_str}."
                else:
                    return f"That falls on {date_str} at {time_str}."
    except Exception as e:
        print(f"Error during dateparser.search_dates execution: {e}")
        print(traceback.format_exc())
        # Return None (not an error string) so callers treat a parser failure
        # the same as "no date found" rather than as a tool answer
        return None
# Function to determine whether a query requires a tool or can be answered directly
def determine_tool_usage(query: str) -> str:
    """
    Analyzes the query to determine if a specific tool is needed.
    Returns the name of the tool ('duckduckgo_search', 'business_info_retrieval',
    'date_calculation') or 'none' if no specific tool is clearly indicated.
    Prioritizes business information retrieval, then specific tools based on
    keywords and LLM judgment.
    """
    query_lower = query.lower()
    # Route Swahili greetings and conversational phrases to date_calculation,
    # which handles them directly
    swahili_conversational_phrases = ['habari', 'mambo', 'shikamoo', 'karibu', 'asante', 'habari gani']
    if any(swahili_phrase in query_lower for swahili_phrase in swahili_conversational_phrases):
        print(f"Detected a Swahili conversational phrase: '{query}'. Using 'date_calculation' tool for initial handling.")
        return "date_calculation"
    # 1. Prioritize business info retrieval if RAG is available
    if business_info_available:
        messages_business_check = [{"role": "user", "content": f"Does the following query ask about a specific person, service, offering, or description that is likely to be found *only* within a specific business's internal knowledge base, and not general knowledge? For example, questions about 'Salum' or 'Jackson Kisanga' are likely business-related, while questions about 'the current president of the USA' or 'who won the Ballon d'Or' are general knowledge. Answer only 'yes' or 'no'. Query: {query}"}]
        try:
            business_check_response = client.chat_completion(
                messages=messages_business_check,
                max_tokens=10,
                temperature=0.1
            ).choices[0].message.content.strip().lower()
            # Require an exact "yes", not just a substring match
            if business_check_response == "yes":
                print(f"Detected as specific business info query based on LLM check: '{query}'")
                return "business_info_retrieval"
            else:
                print(f"LLM check indicates not a specific business info query: '{query}'")
        except Exception as e:
            print(f"Error during LLM call for business info check for query '{query}': {e}")
            print(traceback.format_exc())
            print(f"Proceeding without business info check for query '{query}' due to error.")
    # 2. Check for date calculation (greetings were already handled above)
    date_time_check_result = perform_date_calculation(query)
    if date_time_check_result is not None and not any(phrase in query_lower for phrase in swahili_conversational_phrases):
        print(f"Detected as date/time calculation query based on dateparser result for: '{query}'")
        return "date_calculation"
    # 3. Use the LLM to decide whether a DuckDuckGo search is needed
    messages_tool_determination_search = [{"role": "user", "content": f"Does the following query require searching the web for current or general knowledge information (e.g., news, facts, definitions, current events)? Respond ONLY with 'duckduckgo_search' or 'none'. Query: {query}"}]
    try:
        search_determination_response = client.chat_completion(
            messages=messages_tool_determination_search,
            max_tokens=20,
            temperature=0.1,
            top_p=0.9
        ).choices[0].message.content or ""
        response_lower = search_determination_response.strip().lower()
        if "duckduckgo_search" in response_lower:
            print(f"Model-determined tool for '{query}': 'duckduckgo_search'")
            return "duckduckgo_search"
        else:
            print(f"Model-determined tool for '{query}': 'none' (for search)")
    except Exception as e:
        print(f"Error during LLM call for search tool determination for query '{query}': {e}")
        print(traceback.format_exc())
        print(f"Proceeding without search tool check for query '{query}' due to error.")
    # 4. Default to 'none' when no specific tool is determined
    print(f"No specific tool determined for '{query}'. Defaulting to 'none'.")
    return "none"
# Function to generate text using the LLM, incorporating tool results if available
def generate_text(prompt: str, tool_results: dict = None, chat_history: list[dict] = None) -> str:
    """
    Generates text using the configured LLM, optionally incorporating tool results and chat history.

    Args:
        prompt: The user's latest query.
        tool_results: A dictionary of results from executed tools.
                      Keys are the questions, values are the tool outputs.
        chat_history: The conversation history as a list of dictionaries
                      (as provided by Gradio ChatInterface with type="messages").

    Returns:
        The generated text from the LLM.
    """
    persona_instructions = """You are absa_ai, an AI developed on August 7, 2025, by the absa team. Your knowledge about business data comes from the company's internal Google Sheet.
You are a friendly and helpful chatbot. Respond to greetings appropriately (e.g., "Hello!", "Hi there!", "Habari!"). If the user uses Swahili greetings or simple conversational phrases, respond in Swahili. Otherwise, respond in English unless the query is clearly in Swahili. Handle conversational flow and ask follow-up questions when appropriate.
If the user asks a question about other companies or general knowledge, answer their question. However, subtly remind them that your primary expertise and purpose are related to Absa-specific information.
"""
    # Build the messages list for the chat completion API, with the persona as
    # the system message
    messages = [{"role": "system", "content": persona_instructions}]
    if chat_history:
        print("Including chat history in LLM prompt.")
        # Add only 'user' and 'assistant' turns from the Gradio history to the LLM context
        for message_dict in chat_history:
            role = message_dict.get("role")
            content = message_dict.get("content")
            if role in ["user", "assistant"] and content is not None:
                messages.append({"role": role, "content": content})
    # Add the current user prompt and any tool results
    current_user_content = prompt
    if tool_results and any(tool_results.values()):
        current_user_content += "\n\nTool Results:\n"
        for question, results in tool_results.items():
            if results:
                current_user_content += f"--- Results for: {question} ---\n"
                if isinstance(results, list):
                    for i, result in enumerate(results):
                        if isinstance(result, dict) and 'Service' in result and 'Description' in result:
                            current_user_content += f"Business Info {i+1}:\nService: {result.get('Service', 'N/A')}\nDescription: {result.get('Description', 'N/A')}\n\n"
                        elif isinstance(result, dict) and ('href' in result or 'url' in result):
                            # ddgs results carry the link under 'href'; 'url' is kept as a fallback
                            current_user_content += f"Search Result {i+1}:\nTitle: {result.get('title', 'N/A')}\nURL: {result.get('href') or result.get('url', 'N/A')}\nSnippet: {result.get('body', 'N/A')}\n\n"
                        else:
                            current_user_content += f"{result}\n\n"
                elif isinstance(results, dict):
                    for key, value in results.items():
                        current_user_content += f"{key}: {value}\n"
                    current_user_content += "\n"
                else:
                    current_user_content += f"{results}\n\n"
        current_user_content += "Based on the provided tool results and the conversation history, answer the user's latest query. If a question was answered by a tool, use the tool's result directly in your response. Maintain the language of the original query if possible, especially for simple greetings or direct questions answered by tools."
        print("Added tool results and instruction to final prompt.")
    else:
        current_user_content += "\n\nBased on the conversation history, answer the user's latest query."
        print("No tool results to add to final prompt, relying on conversation history.")
    messages.append({"role": "user", "content": current_user_content})
    print(f"Sending messages to LLM:\n---\n{messages}\n---")
    generation_config = {
        "temperature": 0.7,
        "max_new_tokens": 500,
        "top_p": 0.95,
    }
    try:
        response = client.chat_completion(
            messages=messages,
            max_tokens=generation_config.get("max_new_tokens", 512),
            temperature=generation_config.get("temperature", 0.7),
            top_p=generation_config.get("top_p", 0.95)
        ).choices[0].message.content or ""
        print("LLM generation successful using chat_completion.")
        return response
    except Exception as e:
        print(f"Error during final LLM generation: {e}")
        print(traceback.format_exc())
        return "An error occurred while generating the final response."
def log_conversation(user_query: str, model_response: str, tool_details: dict = None, user_id: str = None):
    """
    Logs one conversation turn (query, response, timestamp, optional details) to a JSONL file.
    """
    timestamp = datetime.now().isoformat()
    log_entry = {
        "timestamp": timestamp,
        "user_query": user_query,
        "model_response": model_response
    }
    if tool_details:
        log_entry["tool_details"] = tool_details
    if user_id:
        log_entry["user_id"] = user_id
    log_file = "conversation_log.jsonl"
    try:
        with open(log_file, "a") as f:
            f.write(json.dumps(log_entry) + "\n")
    except IOError as e:
        print(f"Error writing to log file {log_file}: {e}")
# Main chat function with query breakdown and tool execution per question
def chat(query: str, chat_history: list[dict], api_key: str):
    """
    Processes user queries by breaking down multi-part queries, determining and
    executing the appropriate tool for each question, and synthesizing results
    with the LLM. Prioritizes business information retrieval.
    Requires a valid API key, validated against the APP_API_KEY secret.
    """
    print("chat function received:")
    print(f"  query: {query}")
    print(f"  chat_history: {chat_history}")
    print(f"  api_key from UI: {'*' * len(api_key) if api_key else 'None'}")
    print(f"  Validating against SECRET_API_KEY: {'*' * len(SECRET_API_KEY) if SECRET_API_KEY else 'None'}")
    # Validate the API key passed from the UI against the loaded secret
    if not SECRET_API_KEY:
        print("Error: APP_API_KEY secret not set in Hugging Face Space Secrets.")
        return "API key validation failed: Application not configured correctly. APP_API_KEY secret is missing."
    if api_key != SECRET_API_KEY:
        print("Error: API key from UI does not match SECRET_API_KEY.")
        # Log the failed attempt
        log_conversation(
            user_query=query,
            model_response="API key validation failed: Invalid API key provided.",
            tool_details={"validation_status": "failed", "reason": "invalid_api_key"},
            user_id="unknown"
        )
        return "API key validation failed: Invalid API key provided."
    # Step 1: Query breakdown
    print("\n--- Breaking down query ---")
    prompt_for_question_breakdown = f"""
Analyze the following query and list each distinct question found within it.
Present each question on a new line, starting with a hyphen.

Query: {query}
"""
    try:
        messages_question_breakdown = [{"role": "user", "content": prompt_for_question_breakdown}]
        question_breakdown_response = client.chat_completion(
            messages=messages_question_breakdown,
            max_tokens=100,
            temperature=0.1,
            top_p=0.9
        ).choices[0].message.content or ""
        individual_questions = [line.strip() for line in question_breakdown_response.split('\n') if line.strip()]
        # Strip list markers and drop any notes the LLM adds during breakdown
        cleaned_questions = [re.sub(r'^[-*]?\s*', '', q) for q in individual_questions if not q.strip().lower().startswith('note:')]
        print("Individual questions identified:")
        for q in cleaned_questions:
            print(f"- {q}")
    except Exception as e:
        print(f"Error during LLM call for question breakdown: {e}")
        print(traceback.format_exc())
        cleaned_questions = [query]  # Fall back to treating the whole query as one question
    # Step 2: Tool determination per question
    print("\n--- Determining tools per question ---")
    determined_tools = {}
    for question in cleaned_questions:
        print(f"\nAnalyzing question for tool determination: '{question}'")
        determined_tools[question] = determine_tool_usage(question)
        print(f"Determined tool for '{question}': '{determined_tools[question]}'")
    print("\nSummary of determined tools per question:")
    for question, tool in determined_tools.items():
        print(f"'{question}': '{tool}'")
    # Step 3: Execute tools and collect results
    print("\n--- Executing tools and collecting results ---")
    tool_results = {}
    for question, tool in determined_tools.items():
        print(f"\nExecuting tool '{tool}' for question: '{question}'")
        result = None
        if tool == "date_calculation":
            result = perform_date_calculation(question)
            tool_results[question] = result  # Store the result even if None, for logging
        elif tool == "duckduckgo_search":
            result = perform_duckduckgo_search(question)
            tool_results[question] = result
        elif tool == "business_info_retrieval":
            result = retrieve_business_info(question)
            tool_results[question] = result
        elif tool == "none":
            # The LLM answers this part from its own knowledge in the final
            # response generation step; no tool result is needed here.
            print(f"Skipping tool execution for question: '{question}' as tool is 'none'. LLM will handle.")
            tool_results[question] = "none"  # Indicate that no tool was used
    print("\n--- Collected Tool Results ---")
    if tool_results:
        for question, result in tool_results.items():
            print(f"\nQuestion: {question}")
            print(f"Result: {result}")
    else:
        print("No tool results were collected.")
    print("\n--------------------------")
    # Step 4: Final response generation
    print("\n--- Generating final response ---")
    # chat_history is a list of dictionaries when using type="messages"
    final_response = generate_text(query, tool_results, chat_history)
    print("\n--- Final Response from LLM ---")
    print(final_response)
    print("\n----------------------------")
    # Log the conversation turn AFTER the final response is generated
    try:
        # With Gradio ChatInterface (type="messages") there is no user ID in the
        # chat signature, so fall back to scanning the history for a
        # "user_id: <id>" marker. A production system would take the user ID
        # from its authentication layer instead.
        user_id_to_log = "anonymous"  # Default user ID
        if chat_history:
            for turn in chat_history:
                if turn.get("role") == "user" and "user_id:" in turn.get("content", "").lower():
                    match = re.search(r"user_id:\s*(\S+)", turn.get("content", ""), re.IGNORECASE)
                    if match:
                        user_id_to_log = match.group(1)
                        break  # Found a user ID, stop searching
        # Prepare tool details for logging
        logged_tool_details = {}
        for question, tool_name in determined_tools.items():
            logged_tool_details[question] = {
                "tool_used": tool_name,
                "raw_output": tool_results.get(question)  # Raw output from the tool
            }
        log_conversation(
            user_query=query,
            model_response=final_response,
            tool_details=logged_tool_details,
            user_id=user_id_to_log
        )
    except Exception as e:
        print(f"Error during conversation logging after response generation: {e}")
        print(traceback.format_exc())
    # Return only the latest AI response as a string for Gradio's ChatInterface
    return final_response
# Gradio interface setup
if __name__ == "__main__":
    # Authenticate Google Sheets when the script starts
    authenticate_google_sheets()
    # Load business info after authentication
    load_business_info()
    # Check whether the spaCy model, embedder, and reranker loaded correctly
    if nlp is None:
        print("Warning: SpaCy model not loaded. Sentence splitting may not work correctly.")
    if embedder is None:
        print("Warning: Sentence Transformer (embedder) not loaded. RAG will not be available.")
    if reranker is None:
        print("Warning: Cross-Encoder Reranker not loaded. Re-ranking of RAG results will not be performed.")
    if not business_info_available:
        print("Warning: Business information (Google Sheet data) not loaded successfully. "
              "RAG will not be available. Please ensure the GOOGLE_BASE64_CREDENTIALS secret is set correctly.")
    print("Launching Gradio Interface...")

    DESCRIPTION = """
# LLM with Tools (DuckDuckGo Search, Date Calculation, Business Info RAG)
Ask me anything! I can perform web searches, calculate dates, and retrieve business information.
"""
    # ChatInterface with an additional API key input
    demo = gr.ChatInterface(
        fn=chat,
        stop_btn=None,
        examples=[
            ["Hello there! How are you doing?"],
            ["What is the current time in East Africa?"],
            ["Tell me about the 'Project Management' service from Absa."],
            ["Search the web for the latest news on AI."],
            ["Habari!"],
            ["What is the date next Tuesday?"],
        ],
        cache_examples=False,
        type="messages",
        description=DESCRIPTION,
        fill_height=True,
        additional_inputs=[
            gr.Textbox(label="API Key", type="password", placeholder="Enter your API key (starts with fs_)", interactive=True)
        ],
        additional_inputs_accordion="API Key (Required)"
    )
    try:
        demo.launch(debug=True, show_error=True)
    except Exception as e:
        print(f"Error launching Gradio interface: {e}")
        print(traceback.format_exc())
        print("Please check the console output for more details on the error.")