# This is my app.py
import os
import torch
import re
import warnings
import time
import json
import base64
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from sentence_transformers import SentenceTransformer, util
import gspread
from google.auth import default
from tqdm import tqdm
from duckduckgo_search import DDGS

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)

# --- Configuration ---
SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"  # Your Google Sheet ID
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token from Space Secrets
GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY_BASE64")
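# The secret above holds the service account JSON key encoded as base64.
# One way to produce it (a sketch; 'service_account.json' is whatever your
# key file is called):
#   python -c "import base64; print(base64.b64encode(open('service_account.json', 'rb').read()).decode())"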
# A small instruction-tuned Gemma variant, chosen so the Space can run on CPU.
# model_id = "google/gemma-2b"  # previous choice
model_id = "unsloth/gemma-3-1b-it"
# --- Constants for Prompting and Validation ---
SEARCH_MARKER = "ACTION: SEARCH:"
BUSINESS_LOOKUP_MARKER = "ACTION: LOOKUP_BUSINESS_INFO:"
ANSWER_DIRECTLY_MARKER = "ACTION: ANSWER_DIRECTLY:"
BUSINESS_LOOKUP_VALIDATION_THRESHOLD = 0.6
SEARCH_VALIDATION_THRESHOLD = 0.6
PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD = 0.5
BUSINESS_LOOKUP_DEFAULT_THRESHOLD = 0.50  # default similarity threshold for actual lookups

# --- Global variables (loaded once on startup) ---
tokenizer = None
model = None
embedder = None                # Sentence Transformer
data = []                      # Google Sheet rows
descriptions = []              # 'Description' column values (embedded below)
embeddings = torch.tensor([])  # embeddings of the Google Sheet descriptions
# --- Loading Functions (run once on startup) ---

def load_sentence_transformer():
    """Loads the Sentence Transformer model."""
    print("Loading Sentence Transformer...")
    try:
        embedder_model = SentenceTransformer("all-MiniLM-L6-v2")
        print("Sentence Transformer loaded.")
        return embedder_model
    except Exception as e:
        print(f"Error loading Sentence Transformer: {e}")
        return None

def load_google_sheet_data(sheet_id, service_account_key_base64):
    """Authenticates with a service account and loads data from the Google Sheet."""
    print(f"Attempting to load Google Sheet data from ID: {sheet_id}")
    if not service_account_key_base64:
        print("Warning: GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 secret is not set. Cannot access Google Sheets.")
        return [], [], torch.tensor([])
    try:
        print("Decoding base64 key...")
        key_bytes = base64.b64decode(service_account_key_base64)
        key_dict = json.loads(key_bytes)
        print("Base64 key decoded and parsed.")

        print("Authenticating with service account...")
        from google.oauth2 import service_account
        # Read-only scope is enough here; use
        # 'https://www.googleapis.com/auth/spreadsheets' instead for read/write.
        scopes = ['https://www.googleapis.com/auth/spreadsheets.readonly']
        creds = service_account.Credentials.from_service_account_info(key_dict, scopes=scopes)
        client = gspread.authorize(creds)
        print("Authentication successful.")

        print(f"Opening sheet with key '{sheet_id}'...")
        # NOTE: If your data is not on the first worksheet, open it by name instead:
        # sheet = client.open_by_key(sheet_id).worksheet("Data")
        sheet = client.open_by_key(sheet_id).sheet1
        print(f"Successfully opened Google Sheet with ID: {sheet_id}")

        print("Getting all records from the sheet...")
        sheet_data = sheet.get_all_records()
        print(f"Retrieved {len(sheet_data)} raw records from sheet.")
        if not sheet_data:
            print(f"Warning: No data records found in Google Sheet with ID: {sheet_id}")
            return [], [], torch.tensor([])

        print("Filtering data for 'Service' and 'Description' columns...")
        filtered_data = [row for row in sheet_data if row.get('Service') and row.get('Description')]
        print(f"Filtered down to {len(filtered_data)} records.")
        if not filtered_data:
            print("Warning: Filtered data is empty after checking for 'Service' and 'Description'.")
            # If the sheet has rows but filtering removed everything, the headers are probably wrong.
            if sheet_data and ('Service' not in sheet_data[0] or 'Description' not in sheet_data[0]):
                print("Error: 'Service' or 'Description' headers are missing or misspelled in the sheet.")
            return [], [], torch.tensor([])

        descriptions = [row["Description"] for row in filtered_data]
        print(f"Loaded {len(descriptions)} entries from Google Sheet for embedding.")
        # Embeddings are encoded later, after the embedder has loaded, so return
        # an empty tensor as a placeholder.
        return filtered_data, descriptions, torch.tensor([])
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Google Sheet with ID '{sheet_id}' not found.")
        print("Please check the SHEET_ID and ensure the service account has access.")
        return [], [], torch.tensor([])
    except Exception as e:
        print(f"An error occurred while accessing the Google Sheet: {e}")
        return [], [], torch.tensor([])

def load_llm_model(model_id, hf_token):
    """Loads the LLM and tokenizer in full precision (suitable for CPU)."""
    print(f"Loading model {model_id} in full precision...")
    if not hf_token:
        print("Error: HF_TOKEN secret is not set. Cannot load Hugging Face model.")
        return None, None
    try:
        llm_tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        # Explicitly set a Gemma-style chat template: each message is rendered as
        # <start_of_turn>{role}\n{content}<end_of_turn>\n, and add_generation_prompt=True
        # appends <start_of_turn>model\n so the model continues as the assistant.
        # (The tokenizer itself prepends <bos> when the prompt is encoded.)
        llm_tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'system' %}{{ '<start_of_turn>system\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'tool' %}{{ '<start_of_turn>tool\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'model' %}{{ '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
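        # Illustrative rendering: apply_chat_template([{"role": "user", "content": "Hi"}],
        # tokenize=False, add_generation_prompt=True) produces:
        #   <start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\n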
        if llm_tokenizer.pad_token is None:
            llm_tokenizer.pad_token = llm_tokenizer.eos_token
        llm_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            device_map="auto",  # resolves to 'cpu' on a CPU-only Space
        )
        print(f"Model {model_id} loaded in full precision.")
        return llm_model, llm_tokenizer
    except Exception as e:
        print(f"Error loading model {model_id}: {e}")
        print("Please ensure transformers and accelerate are installed, and check your Hugging Face token.")
        return None, None
# --- Load all assets on startup ---
print("Loading assets...")

embedder = load_sentence_transformer()
print(f"Embedder loaded: {embedder is not None}")

data, descriptions, _ = load_google_sheet_data(SHEET_ID, GOOGLE_SERVICE_ACCOUNT_KEY_BASE64)
print(f"Google Sheet data loaded: {len(data)} rows")
print(f"Google Sheet descriptions loaded: {len(descriptions)} items")

if embedder and descriptions:
    print("Encoding Google Sheet descriptions...")
    try:
        embeddings = embedder.encode(descriptions, convert_to_tensor=True)
        print("Encoding complete.")
        print(f"Embeddings shape: {embeddings.shape}")
    except Exception as e:
        print(f"Error during embedding: {e}")
        embeddings = torch.tensor([])  # fall back to an empty tensor on error
else:
    print("Skipping embedding due to missing embedder or descriptions.")
    embeddings = torch.tensor([])  # empty tensor when embedding is skipped
    print(f"Embeddings tensor after skip: {embeddings.shape}")  # torch.Size([0])

model, tokenizer = load_llm_model(model_id, HF_TOKEN)
print(f"LLM Model loaded: {model is not None}")
print(f"LLM Tokenizer loaded: {tokenizer is not None}")
# --- Startup health check ---
# Summarizes anything that failed during loading; respond() guards again at runtime.
if not model or not tokenizer or not embedder or embeddings is None or embeddings.numel() == 0 or not data:
    print("\nERROR: Essential components failed to load. The application may not function correctly.")
    if not model: print("- LLM Model failed to load.")
    if not tokenizer: print("- LLM Tokenizer failed to load.")
    if not embedder: print("- Sentence Embedder failed to load.")
    if embeddings is None or embeddings.numel() == 0: print("- Embeddings are empty or None.")
    if not data: print("- Google Sheet Data is empty.")
else:
    print("\nAll essential components loaded successfully.")
# --- Helper Functions ---

def perform_duckduckgo_search(query, max_results=3):
    """
    Performs a DuckDuckGo text search and returns a list of result dictionaries.
    Includes a short delay to avoid rate limiting.
    """
    search_results_list = []
    try:
        time.sleep(1)  # crude rate limiting between searches
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=max_results):
                search_results_list.append(r)
    except Exception as e:
        print(f"Error during DuckDuckGo search for '{query}': {e}")
        return []
    return search_results_list
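# Each result dict from DDGS().text() carries the fields consumed later in
# respond() -- e.g. (illustrative):
#   {'title': 'Example Domain', 'href': 'https://example.com', 'body': 'Snippet text...'}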

def retrieve_business_info(query, data, embeddings, embedder, threshold=BUSINESS_LOOKUP_DEFAULT_THRESHOLD):
    """
    Retrieves the most relevant business entry for a query via cosine similarity.
    Returns (best_match_row, score) when the best score clears the threshold,
    otherwise (None, score).
    """
    if not data or embeddings is None or embeddings.numel() == 0 or embedder is None:
        print("Skipping business info retrieval: data, embeddings, or embedder not available.")
        return None, 0.0
    try:
        user_embedding = embedder.encode(query, convert_to_tensor=True)
        cos_scores = util.cos_sim(user_embedding, embeddings)[0]
        best_score = cos_scores.max().item()
        if best_score > threshold:
            best_match_idx = cos_scores.argmax().item()
            return data[best_match_idx], best_score
        return None, best_score
    except Exception as e:
        print(f"Error during business information retrieval: {e}")
        return None, 0.0
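# Illustrative usage: with a "Haircut" row in the sheet,
#   retrieve_business_info("how much is a haircut?", data, embeddings, embedder)
# might return ({'Service': 'Haircut', 'Description': ..., 'Price': ...}, 0.71),
# where 0.71 is the cosine similarity of the best-matching description.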

# Regex-based alternative to spaCy for splitting queries
def split_query(query):
    """Splits a user query into potential sub-queries using regex."""
    # Split on commas, semicolons, and "and" followed by an interrogative word.
    # Note: re.split consumes the matched separator, including the interrogative word.
    parts = re.split(r',|;|\band\s+(?:who|what|where|when|why|how|is|are|can|tell me about)\b', query, flags=re.IGNORECASE)
    # Filter out empty strings and strip whitespace
    parts = [part.strip() for part in parts if part and part.strip()]
    # If splitting didn't produce multiple meaningful parts, return the original query
    if len(parts) <= 1:
        return [query]
    return parts
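# Illustrative behavior (note the interrogative word is consumed by the split):
#   split_query("what services do you offer and who is the owner")
#   -> ['what services do you offer', 'is the owner']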

# --- Pass 1 System Prompt ---
pass1_instructions_action = """You are a helpful assistant for a business. Your primary goal in this first step is to analyze the user's query and decide which actions are needed to answer it.
You have analyzed the user's query and potentially broken it down into parts. For each part, a preliminary check was done to see if it matches known business information. The results of this check are provided below.
{business_check_summary}
Based on the user's query and the results of the business info check for each part, identify if you need to perform actions.
Output one or more actions, each on a new line, in the format:
ACTION: [ACTION_TYPE]: [Argument/Query for the action]
Possible actions:
1. **LOOKUP_BUSINESS_INFO**: If a part of the query asks about the business's services, prices, availability, or individuals mentioned in the business context, *and* the business info check for that part indicates a high relevance ({PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher). The argument should be the specific phrase or name to look up.
2. **SEARCH**: If a part of the query asks for current external information (e.g., current events, real-time data, general facts not in business info), *or* if a part that seems like it could be business info did *not* have a high relevance score in the preliminary check (below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}). The argument should be the precise search query.
3. **ANSWER_DIRECTLY**: If the overall query is a simple greeting or can be answered from your general knowledge without lookup or search, *and* the business info check results indicate low relevance for all parts. The argument should be the direct answer.
**Crucially:**
- **Prioritize LOOKUP_BUSINESS_INFO** for any part of the query where the preliminary business info check score was {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher.
- Use **SEARCH** for parts about external information or where the business info check score was below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}.
- If a part of the query is clearly external (like asking about current events or famous people), use SEARCH for it even if its business info score wasn't zero.
- Do NOT output any other text besides the ACTION lines.
- If the results suggest a direct answer is sufficient, use ANSWER_DIRECTLY.
Now, analyze the following user query, considering the business info check results provided above, and output the required actions:
"""

# --- Pass 2 System Prompt ---
pass2_instructions_synthesize = """You are a helpful assistant for a business. You have been provided with the original user query, relevant Business Information (if found), and results from external searches (if performed).
Your task is to synthesize ALL the provided information to answer the user's original question concisely and accurately.
**Prioritize Business Information** for details about the business, its services, or individuals mentioned within that context.
Use the Search Results for current external information that was requested.
If information for a specific part of the question was not found in either Business Information or Search Results, use your general knowledge if possible, or state that the information could not be found.
Synthesize the information into a natural language response. Do NOT copy and paste raw context or strings like 'Business Information:' or 'SEARCH RESULTS:' or 'ACTION:' or the raw user query.
After your answer, generate a few concise follow-up questions that a user might ask based on the previous turn's conversation and your response. List these questions clearly at the end of your response.
When search results were used to answer the question, list the URLs from the search results you used under a "Sources:" heading at the very end.
"""

# --- Main Inference Function for Gradio ---

def respond(user_input, chat_history):
    """
    Processes user input, performs actions (lookup/search), and generates a response.
    Manages chat history within Gradio state.
    """
    # Bail out early if the models failed to load
    if model is None or tokenizer is None or embedder is None:
        return "", chat_history + [(user_input, "Sorry, the application failed to load necessary components. Please try again later or contact the administrator.")]

    original_user_input = user_input

    # Containers for this turn's action results
    search_results_dicts = []
    business_lookup_results_formatted = []
    response_pass1_raw = ""

    # --- Pre-Pass 1: Programmatic business info check on query parts ---
    query_parts = split_query(original_user_input)
    business_check_results = []
    overall_pre_pass1_score = 0.0
    print("\n--- Processing new user query ---")
    print(f"User: {user_input}")
    print("Performing programmatic business info check on query parts...")
    if query_parts:
        for part in query_parts:
            match, score = retrieve_business_info(part, data, embeddings, embedder, threshold=0.0)
            business_check_results.append({"part": part, "score": score, "match": match})
            print(f"- Part '{part}': Score {score:.4f}")
            overall_pre_pass1_score = max(overall_pre_pass1_score, score)
    else:
        match, score = retrieve_business_info(original_user_input, data, embeddings, embedder, threshold=0.0)
        business_check_results.append({"part": original_user_input, "score": score, "match": match})
        print(f"- Part '{original_user_input}': Score {score:.4f}")
        overall_pre_pass1_score = score

    is_likely_direct_answer = overall_pre_pass1_score < PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD and len(query_parts) <= 2

    # Format the business check summary for the Pass 1 prompt
    business_check_summary = "Business Info Check Results for Query Parts:\n"
    if business_check_results:
        for result in business_check_results:
            status = "High Relevance" if result['score'] >= PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD else "Low Relevance"
            business_check_summary += f"- Part '{result['part']}': Score {result['score']:.4f} ({status})\n"
    else:
        business_check_summary += "- No parts identified or check skipped.\n"
    business_check_summary += "\n"

    # --- Pass 1: Action identification (skipped if a direct answer looks likely) ---
    requested_actions = []
    answer_directly_provided = None
    if is_likely_direct_answer:
        print("Programmatically determined likely direct answer.")
        response_pass1_raw = "ACTION: ANSWER_DIRECTLY: "
    else:
        pass1_user_message_content = pass1_instructions_action.format(
            business_check_summary=business_check_summary,
            PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD=PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD
        ) + "\n\nUser Query: " + user_input
        temp_chat_history_pass1 = [{"role": "user", "content": pass1_user_message_content}]
        try:
            prompt_pass1 = tokenizer.apply_chat_template(
                temp_chat_history_pass1,
                tokenize=False,
                add_generation_prompt=True
            )
            generation_config_pass1 = GenerationConfig(
                max_new_tokens=200,
                do_sample=False,  # greedy decoding for deterministic action output
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True
            )
            input_ids_pass1 = tokenizer(prompt_pass1, return_tensors="pt").input_ids
            if model and input_ids_pass1.numel() > 0:
                outputs_pass1 = model.generate(
                    input_ids=input_ids_pass1,
                    generation_config=generation_config_pass1,
                )
                # Decode only the newly generated tokens, not the prompt
                prompt_length_pass1 = input_ids_pass1.shape[1]
                if outputs_pass1.shape[1] > prompt_length_pass1:
                    generated_tokens_pass1 = outputs_pass1[0, prompt_length_pass1:]
                    response_pass1_raw = tokenizer.decode(generated_tokens_pass1, skip_special_tokens=True).strip()
                else:
                    response_pass1_raw = ""
            else:
                response_pass1_raw = ""
        except Exception as e:
            print(f"Error during Pass 1 (Action Identification): {e}")
            response_pass1_raw = f"ACTION: ANSWER_DIRECTLY: Error in Pass 1 - {e}"

    # --- Parse the model's requested actions, validating each against the embeddings ---
    if response_pass1_raw:
        lines = response_pass1_raw.strip().split('\n')
        for line in lines:
            line = line.strip()
            if line.startswith(SEARCH_MARKER):
                query = line[len(SEARCH_MARKER):].strip()
                if query:
                    # Reject searches that look like business data; those should be lookups.
                    _, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)
                    if score < SEARCH_VALIDATION_THRESHOLD:
                        requested_actions.append(("SEARCH", query))
                        print(f"Validated Search Action for '{query}' (Score: {score:.4f})")
                    else:
                        print(f"Rejected Search Action for '{query}' (Score: {score:.4f}) - Too similar to business data.")
            elif line.startswith(BUSINESS_LOOKUP_MARKER):
                query = line[len(BUSINESS_LOOKUP_MARKER):].strip()
                if query:
                    match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)
                    if score > BUSINESS_LOOKUP_VALIDATION_THRESHOLD:
                        requested_actions.append(("LOOKUP_BUSINESS_INFO", query))
                        print(f"Validated Business Lookup Action for '{query}' (Score: {score:.4f})")
                    else:
                        print(f"Rejected Business Lookup Action for '{query}' (Score: {score:.4f}) - Below validation threshold.")
            elif line.startswith(ANSWER_DIRECTLY_MARKER):
                answer = line[len(ANSWER_DIRECTLY_MARKER):].strip()
                answer_directly_provided = answer if answer else original_user_input
                requested_actions = []  # a direct answer overrides any other actions
                break
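    # Illustrative Pass 1 output and how it parses (assuming both lines pass validation):
    #   ACTION: LOOKUP_BUSINESS_INFO: haircut prices
    #   ACTION: SEARCH: weather in Nairobi today
    # -> requested_actions == [("LOOKUP_BUSINESS_INFO", "haircut prices"),
    #                          ("SEARCH", "weather in Nairobi today")]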
    # --- Execute the validated actions (search and lookup) ---
    context_for_pass2 = ""
    if requested_actions:
        print("Executing requested actions...")
        for action_type, query in requested_actions:
            if action_type == "SEARCH":
                print(f"Performing search for: '{query}'")
                results = perform_duckduckgo_search(query)
                if results:
                    search_results_dicts.extend(results)
                    print(f"Found {len(results)} search results.")
                else:
                    print(f"No search results found for '{query}'.")
            elif action_type == "LOOKUP_BUSINESS_INFO":
                print(f"Performing business info lookup for: '{query}'")
                match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=BUSINESS_LOOKUP_DEFAULT_THRESHOLD)
                print(f"Actual lookup score for '{query}': {score:.4f} (Threshold: {BUSINESS_LOOKUP_DEFAULT_THRESHOLD})")
                if match:
                    formatted_match = f"""Service: {match.get('Service', 'N/A')}
Description: {match.get('Description', 'N/A')}
Price: {match.get('Price', 'N/A')}
Available: {match.get('Available', 'N/A')}"""
                    business_lookup_results_formatted.append(formatted_match)
                    print("Found business info match.")
                else:
                    print(f"No business info match found for '{query}' at threshold {BUSINESS_LOOKUP_DEFAULT_THRESHOLD}.")

    # --- Prepare context for Pass 2 based on the executed actions ---
    if business_lookup_results_formatted:
        context_for_pass2 += "Business Information (Use this for questions about the business):\n"
        context_for_pass2 += "\n---\n".join(business_lookup_results_formatted)
        context_for_pass2 += "\n\n"
    if search_results_dicts:
        context_for_pass2 += "SEARCH RESULTS (Use this for current external information):\n"
        aggregated_search_results_formatted = []
        for result in search_results_dicts:
            aggregated_search_results_formatted.append(f"Title: {result.get('title', 'N/A')}\nSnippet: {result.get('body', 'N/A')}\nURL: {result.get('href', 'N/A')}")
        context_for_pass2 += "\n---\n".join(aggregated_search_results_formatted) + "\n\n"
    if requested_actions and not business_lookup_results_formatted and not search_results_dicts:
        context_for_pass2 = "Note: No relevant information was found in Business Information or via Search for your query."
        print("Note: No results were found for the requested actions.")

    # If ANSWER_DIRECTLY was determined, discard any action results
    if answer_directly_provided is not None:
        print(f"Handling as direct answer: {answer_directly_provided}")
        context_for_pass2 = "Note: This query is a simple request or greeting."
        if answer_directly_provided != original_user_input and answer_directly_provided != "":
            context_for_pass2 += f" Initial suggestion from action step: {answer_directly_provided}"
        search_results_dicts = []
        business_lookup_results_formatted = []

    # If Pass 1 produced neither valid actions nor a direct answer
    if not requested_actions and answer_directly_provided is None:
        if response_pass1_raw.strip():
            print("Warning: Pass 1 did not result in valid actions or a direct answer.")
            context_for_pass2 = f"Error: Could not determine actions from Pass 1 response: '{response_pass1_raw}'."
        else:
            print("Warning: Pass 1 generated an empty response.")
            context_for_pass2 = "Error: Pass 1 generated an empty response."

    # --- Pass 2: Synthesize and respond ---
    final_response = "Sorry, I couldn't generate a response."
    if model is not None and tokenizer is not None:
        pass2_user_message_content = pass2_instructions_synthesize + "\n\nOriginal User Query: " + original_user_input + "\n\n" + context_for_pass2
        # Rebuild the chat history in the role/content format the chat template expects
        model_chat_history = []
        for user_msg, bot_msg in chat_history:
            model_chat_history.append({"role": "user", "content": user_msg})
            model_chat_history.append({"role": "assistant", "content": bot_msg})
        model_chat_history.append({"role": "user", "content": pass2_user_message_content})
        try:
            prompt_pass2 = tokenizer.apply_chat_template(
                model_chat_history,
                tokenize=False,
                add_generation_prompt=True
            )
            generation_config_pass2 = GenerationConfig(
                max_new_tokens=1500,
                do_sample=True,  # sampling for a more natural final answer
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True
            )
            input_ids_pass2 = tokenizer(prompt_pass2, return_tensors="pt").input_ids
            if model and input_ids_pass2.numel() > 0:
                outputs_pass2 = model.generate(
                    input_ids=input_ids_pass2,
                    generation_config=generation_config_pass2,
                )
                prompt_length_pass2 = input_ids_pass2.shape[1]
                if outputs_pass2.shape[1] > prompt_length_pass2:
                    generated_tokens_pass2 = outputs_pass2[0, prompt_length_pass2:]
                    final_response = tokenizer.decode(generated_tokens_pass2, skip_special_tokens=True).strip()
                else:
                    final_response = "..."
            else:
                final_response = "Error: Model or empty input for Pass 2."
        except Exception as gen_error:
            print(f"Error during model generation in Pass 2: {gen_error}")
            final_response = "Error generating response in Pass 2."

        # --- Post-process the Pass 2 response ---
        # Strip any leaked scaffolding (context headings, prompt fragments) from the answer.
        cleaned_response = final_response
        lines = cleaned_response.split('\n')
        cleaned_lines = [line for line in lines if not line.strip().lower().startswith("business information")
                         and not line.strip().lower().startswith("search results")
                         and not line.strip().startswith("---")
                         and not line.strip().lower().startswith("original user query:")
                         and not line.strip().lower().startswith("you are a helpful assistant for a business.")]
        cleaned_response = "\n".join(cleaned_lines).strip()

        # Append the deduplicated source URLs from any searches performed this turn
        urls_to_list = [result.get('href') for result in search_results_dicts if result.get('href')]
        urls_to_list = list(dict.fromkeys(urls_to_list))  # dedupe, preserving order
        if search_results_dicts and urls_to_list:
            cleaned_response += "\n\nSources:\n" + "\n".join(urls_to_list)

        final_response = cleaned_response
        if not final_response.strip():
            final_response = "Sorry, I couldn't generate a meaningful response based on the information found."
            print("Warning: Final response was empty after cleaning.")
    else:
        final_response = "Sorry, the core language model is not available."
        print("Error: LLM model or tokenizer not loaded for Pass 2.")

    # --- Update chat history for Gradio, keeping only the most recent turns ---
    updated_chat_history = chat_history + [(original_user_input, final_response)]
    max_history_pairs = 10  # keep the last 10 (user, bot) pairs
    if len(updated_chat_history) > max_history_pairs:
        updated_chat_history = updated_chat_history[-max_history_pairs:]
    return "", updated_chat_history