# Candles A.I. chat support — Gradio app: text-to-SQL over PostgreSQL via Groq LLMs.
# Standard libraries
import asyncio
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import lru_cache
from pathlib import Path
from typing import List, Optional
from urllib.parse import quote_plus

# External libraries
import gradio as gr
import numpy as np
import pandas as pd
import psycopg2
import requests
from dotenv import load_dotenv
from groq import Groq
from sqlalchemy import create_engine, inspect, text
#----------------------- | |
# Fetch variables | |
user = os.getenv("user") | |
password = os.getenv("password") | |
host = os.getenv("host") | |
port = os.getenv("port") | |
dbname = os.getenv("dbname") | |
GROQ_API_KEY = os.getenv("GROQ_API_KEY_PAID") | |
GROQ_API_KEY_PAID = os.getenv("GROQ_API_KEY_PAID") | |
# Define path for cache file in the current directory | |
CACHE_FILE = Path("/tmp/sql_prompt_cache1.json") | |
#----------------------- | |
# create db connection | |
engine = create_engine(f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{dbname}") | |
# print(engine) | |
# ----------------------- | |
# Groq API settings | |
GROQ_EMBED_URL = "https://api.groq.com/openai/v1/embeddings" | |
GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions" | |
EMBEDDING_MODEL = "llama3-405b-8192-embed" | |
# LLM_MODEL = "llama-3.3-70b-versatile" | |
LLM_MODEL = "llama3-70b-8192" | |
cMODEL = "llama-3.1-8b-instant" | |
MODEL = "llama-3.3-70b-versatile" | |
#----------------------- | |
# Configure headers for Groq API requests | |
GROQ_HEADERS = { | |
"Authorization": f"Bearer {GROQ_API_KEY}", | |
"Content-Type": "application/json"} | |
#----------------------- | |
# # version 1 (tables) | |
# def get_db_schema(engine): | |
# inspector = inspect(engine) | |
# schema_dict = {} | |
# for table in inspector.get_table_names(): | |
# columns = inspector.get_columns(table) | |
# column_names = [col['name'] for col in columns] | |
# schema_dict[table] = column_names | |
# return schema_dict | |
# schema = get_db_schema(engine) | |
schema = { | |
'rag_color_conversion_chart': ['brand_name','brand_color_code','brand_color','candle_shade_code','matching_type'], | |
'rag_order_detail': ['orderno','customername', 'phone', 'email', 'orderdate', 'producttype', 'colorcode', 'orderedqty', | |
'status','tracking_number','shippingaddress', 'shipping_type', 'estimateddeliverydate', | |
'delivery_date', 'total_amount'], | |
'rag_promo_code_data': ['promocode','username','firstname','lastname','companyname','phone','issuancedate','expirydate','user_id'], | |
'rag_color_conversion_charts_link': ['Conversion Chart Name', 'link'], | |
'rag_certificate_link': ['Certificates Name', 'link'] | |
} | |
# print(schema) | |
# ----------------------- | |
# # Bag of words from table | |
# def get_db_schema_words(engine): | |
# inspector = inspect(engine) | |
# words = [] | |
# for table in inspector.get_table_names(): | |
# words.extend(table.split("_")) # Split table names into words | |
# columns = inspector.get_columns(table) | |
# for col in columns: | |
# words.extend(col['name'].split("_")) # Split column names into words | |
# return list(set(words)) # Remove duplicates, if needed | |
# bag_of_words = get_db_schema_words(engine) | |
# bag_of_words = schema | |
bag_of_words = schema['rag_color_conversion_chart'],schema['rag_order_detail'],schema['rag_promo_code_data'],schema['rag_color_conversion_charts_link'],schema['rag_certificate_link'] | |
# ----------------------- | |
# Query formatter | |
def clean_sql_output(raw_content: str) -> str: | |
# Remove any markdown formatting and explanations | |
raw_content = raw_content.replace("```sql", "").replace("```", "").strip() | |
# Extract only the first valid SQL query using regex | |
queries = re.findall(r"(SELECT\s+.+?;)", raw_content, | |
flags=re.IGNORECASE | re.DOTALL) | |
return queries[0].strip() if queries else "-- Unable to extract valid SQL" | |
# ----------------------- | |
def query_groq( | |
user_prompt: str, | |
schema: Optional[str] = None, | |
system_prompt: str = "You are a PostgreSQL expert. Generate the best possible SQL queries for FILTERING based on the user's question. **FILTERING** must be done using **IN** or **OR**. Also If user asks for colors and matches and provides color codes, query for **brand_code** or **brand_color_code** by default. SQL queries must be design for **returning all attributes**. Return only the correct SQL query. Note: attributes datatypes are **text**.", | |
model: str = LLM_MODEL, | |
temperature: float = 0.3, | |
# max_tokens: int = 8192 | |
) -> str: | |
full_usr_prompt = f"User Question: {user_prompt}\n\nRefer to these target keywords for SQL Queries:{bag_of_words}" if bag_of_words else user_prompt # user | |
full_sys_prompt = f"{system_prompt}\n\nRefer to this Schema for SQL Queries:{schema}" if schema else system_prompt # system | |
response = requests.post( | |
GROQ_CHAT_URL, | |
headers=GROQ_HEADERS, | |
json={ | |
"model": model, | |
"messages": [ | |
{"role": "system", "content": full_sys_prompt}, | |
{"role": "user", "content": full_usr_prompt} | |
], | |
"temperature": temperature, | |
# "max_tokens": max_tokens | |
} | |
) | |
if response.status_code != 200: | |
raise Exception(f"❌ Error querying Groq:\n{response.text}") | |
# Clean output: remove triple backticks | |
content = response.json()["choices"][0]["message"]["content"].strip() | |
return clean_sql_output(content) | |
# ----------------------- | |
client = Groq(api_key=GROQ_API_KEY) | |
def correct_spelling_with_groq(text, context): | |
try: | |
chat_completion = client.chat.completions.create( | |
messages=[ | |
# {"role": "system", "content": f"You are a helpful assistant that uses this context of words: {context}, to correct spellings in user queries, keeping their intent intact. Also if user asks for thread, he is refering to to find candle color code match. If user asks for colors search for brand colors"}, | |
{"role": "system", "content": f"You are a helpful assistant that uses this context of words: {context}, to correct spellings in user queries, keeping their intent intact. Also if user asks for thread, he is refering to to find candle_shade_code match. Also If user asks for colors and matches and provides color codes, search for brand code or brand color code by default"}, | |
{"role": "user", "content": f"Correct the following query: {text}"} | |
], | |
model=cMODEL, # or the correct model name like llama-3.3-70b-versatile | |
temperature=0.3, | |
max_completion_tokens=256, | |
top_p=1, | |
stream=False, | |
) | |
return chat_completion.choices[0].message.content.strip() | |
except Exception as e: | |
print("Groq correction failed:", e) | |
return text # fallback to original if Groq fails | |
# ----------------------- | |
# summarize = Groq(api_key=GROQ_API_KEY) | |
summarize = Groq(api_key=GROQ_API_KEY_PAID) | |
def summarize_with_groq(text, context): | |
if not text.strip(): | |
return "No content provided to summarize." | |
try: | |
response = summarize.chat.completions.create( | |
messages=[ | |
{"role": "system", "content": f"""You are a helpful assistant that organizes and numerically lists and sub-lists closely related important text and shows and handles all unique variations and handles variations duplicates from the user inputs and its **context : {context}** in a professional manner and Always presents the output in a clean, professional format.. | |
[specially When the input contains order-related information, follow this following **strict format** without changes or additions: | |
--- | |
Order: 2677 | |
Customer Name: Brenda Cole | |
Customer Email: brendacole66@gmail.com | |
Phone Number: 12316130502 | |
Order Date: 2024-01-02 | |
Shipping Address: Title: Default, Address: 2606 N Lakeshore Dr Suit/Apt # , Zip Code: 49431 | |
Status: Delivered | |
Total Amount: 225.5 | |
--- | |
A. 3-5 Days Shipping: | |
- Tracking Link: https://wwwapps.ups.com/WebTracking/track?track=yes&trackNums=1ZB1F8280339024054 | |
a. Material: Polyester | |
1. Color Code: 7834 - Qty: 1, | |
2. Color Code: 61060 - Qty: 5, | |
3. Color Code: 60300 - Qty: 7, | |
4. Color Code: 8717 - Qty: 4, | |
--- | |
b. Material: Trial Order | |
1. Color Code: 7834 - Qty: 1, | |
2. Color Code: 6106 - Qty: 5, | |
--- | |
B. 10-15 Days Shipping: | |
- Tracking Link: Not Available | |
a. Material: Polyester | |
1. Color Code: 07834 - Qty: 1, | |
2. Color Code: 61060 - Qty: 5, | |
3. Color Code: 6030 - Qty: 7, | |
4. Color Code: 8717 - Qty: 4, | |
--- | |
b. Material: Trial Order | |
1. Color Code: 07834 - Qty: 1, | |
2. Color Code: 6106 - Qty: 5 | |
---] | |
**Important Guidelines:** | |
- Use **only** the information provided by the user — **do not** assume or add any details. | |
- **Do not** omit or miss any information. | |
- Use **Given format** for order related queries, and use normal formats for other queries. | |
- Enter all remaining items without stopping until complete. | |
- **Do not** use tables. | |
- **Do not** display processing information. | |
- **Do not** display dictionary. | |
- Always present the output in a **clean**, **professional** and **Customer friendly** format. | |
"""}, | |
# {"role": "user", "content": f"Please organize the following text and numerically lists and sub-lists closely related important text and show and handle all unique variations and handle duplicates variations professionally:\n\n{text} "} | |
{"role": "user", "content": f"{text}"} | |
], | |
model=MODEL, | |
temperature=0.25, | |
# max_completion_tokens=8192, | |
top_p=1, | |
stream=False, | |
) | |
return response.choices[0].message.content.strip() | |
except Exception as e: | |
print("Groq Summarization failed:", e) | |
return text | |
# ----------------------- | |
def load_cache(): | |
if os.path.exists(CACHE_FILE): | |
with open(CACHE_FILE, "r") as f: | |
return json.load(f) | |
return {} | |
def save_cache(cache: dict): | |
with open(CACHE_FILE, "w") as f: | |
json.dump(cache, f, indent=2) | |
# ----------------------- | |
def try_sql_parallel_with_disk_cache( | |
prompt: str, | |
engine, | |
schema: Optional[str] = None, | |
max_prompts: int = 6, | |
max_retries: int = 2, | |
verbose: bool = False | |
): | |
SQL_CACHE = load_cache() | |
def generate_variants(base_prompt: str, error: str, count: int) -> List[str]: | |
return [ | |
f"{base_prompt}\n\nTry variation #{i+1}.\nPrevious error:\n{error}" if error else f"{base_prompt}\n\nVariation #{i+1}" | |
for i in range(count) | |
] | |
def try_sql(sql: str): | |
try: | |
df = pd.read_sql(sql, engine) | |
return df if not df.empty else None | |
except Exception: | |
return None | |
attempts = 0 | |
error_message = "" | |
last_sql = "" | |
# ✅ Check cache first | |
if prompt in SQL_CACHE: | |
if verbose: | |
print("[Disk cache hit] Using cached SQL.") | |
sql = SQL_CACHE[prompt] | |
df = try_sql(sql) | |
if df is not None: | |
return {"success": True, "dataframe": df} | |
while attempts < max_retries: | |
attempts += 1 | |
prompt_variants = generate_variants(prompt, error_message, max_prompts) | |
with ThreadPoolExecutor(max_workers=max_prompts) as executor: | |
futures = {executor.submit(query_groq, p, schema=schema): p for p in prompt_variants} | |
new_sqls = [] | |
for future in as_completed(futures): | |
try: | |
new_sql = future.result() | |
new_sqls.append((futures[future], new_sql)) | |
except Exception: | |
continue | |
for p, sql in new_sqls: | |
last_sql = sql | |
if not sql or not sql.lower().strip().startswith("select"): | |
error_message = "-- Not a SELECT statement" | |
continue | |
df = try_sql(sql) | |
if df is not None: | |
# ✅ Cache only successful result | |
SQL_CACHE[prompt] = sql | |
save_cache(SQL_CACHE) | |
return {"success": True, "dataframe": df} | |
error_message = "-- Query failed or returned no rows" | |
return {"success": False, "error": error_message} | |
# ----------------------- | |
with gr.Blocks() as interface: | |
gr.Markdown("<h1 style='text-align: center;'>🕯️ CANDLES A.I CHAT SUPPORT 🕯️</h1>") | |
gr.Markdown("<h3 style='text-align: center;'>Please type your query below. For better results, include your full name or username, order no, brand color or candle color, user ID.</h3>") | |
chat_history = gr.State([]) | |
# Chatbot full-width, larger, scrollable, and copyable | |
chatbot = gr.Chatbot(label="Conversation", height=600, autoscroll=True, show_copy_button=True) | |
# Input and submit row | |
with gr.Row(equal_height=True): | |
user_input = gr.Textbox( | |
label="Your Question", | |
placeholder="e.g., John Smith, 1800, show matching candles etc.", | |
autofocus=True, | |
scale=4 | |
) | |
submit_btn = gr.Button("Submit", variant="primary", scale=1) | |
# Undo and clear row | |
with gr.Row(): | |
undo_btn = gr.Button("Undo Last") | |
clear_btn = gr.Button("Clear") | |
# Status feedback message | |
status_msg = gr.Markdown("") | |
# Main logic without SQL table | |
def handle_submit(user_input, history): | |
if not user_input.strip(): | |
return gr.update(), history, "⚠️ Please enter a valid query." | |
# correct_user_input = correct_spelling_with_groq(user_input , f"{bag_of_words}") | |
# history.append(("🧑 User", user_input)) | |
# sql_attempt = try_sql_parallel_with_disk_cache(user_input, engine, schema, max_prompts=7, max_retries=2) | |
sql_attempt = try_sql_parallel_with_disk_cache(user_input, engine, schema) | |
if sql_attempt["success"]: | |
# summary = summarize_with_groq(f"{correct_spelling_with_groq}\n\n{sql_attempt['dataframe'].to_dict()}" , f"{sql_attempt['dataframe'].columns.to_list()}") | |
summary = summarize_with_groq(f"{sql_attempt['dataframe'].to_dict()}" , f"{sql_attempt['dataframe'].columns.to_list()}") | |
answer = f"✅ Query executed successfully.\n\n{summary}" | |
status = "✅ Success" | |
else: | |
# summary = summarize_with_groq(f"User query: {user_input}", f"{correct_user_input}") | |
correct_user_input = correct_spelling_with_groq(user_input , f"{bag_of_words}") | |
answer = f"❌ Please try :\n\n{correct_user_input}" | |
status = "❌ Error: Invalid query or no matching data." | |
history.append(("🧑 User", user_input)) | |
history.append(("🤖 Assistant", answer)) | |
return "", history, status | |
# Button bindings | |
submit_btn.click( | |
handle_submit, | |
[user_input, chat_history], | |
[user_input, chat_history, status_msg] | |
).then( | |
lambda hist: hist, | |
[chat_history], | |
[chatbot] | |
) | |
user_input.submit( | |
handle_submit, | |
[user_input, chat_history], | |
[user_input, chat_history, status_msg] | |
).then( | |
lambda hist: hist, | |
[chat_history], | |
[chatbot] | |
) | |
undo_btn.click( | |
lambda history: history[:-2] if len(history) >= 2 else [], | |
[chat_history], | |
[chat_history] | |
).then( | |
lambda hist: hist, | |
[chat_history], | |
[chatbot] | |
) | |
clear_btn.click( | |
lambda: [], | |
None, | |
[chat_history] | |
).then( | |
lambda: ([], ""), | |
None, | |
[chatbot, status_msg] | |
) | |
interface.launch(ssr_mode=False,share=True) | |
# ----------------------- | |