|
from services.ai import rank_tables, generate_sql, generate_answer, correct_sql, evaluate_difficulty
from services.utils import filter_tables
from openai import OpenAI
from database import Database
from filesource import FileSource
import os

MAX_TABLE = 3

client = OpenAI(
    base_url=os.getenv("LLM_ENDPOINT"),
    api_key=os.getenv("LLM_KEY"),
)


def run_agent(database, prompt, give_answer=True):
    retry = 5
    tables = database.get_tables_array()
    use_thinking = False

    # When the schema is large, rerank the tables against the prompt and keep
    # only the MAX_TABLE most relevant ones.
    if len(tables) > MAX_TABLE:
        print(f"Using reranking because the number of tables is greater than {MAX_TABLE}")
        ranked = rank_tables(prompt, tables)
        tables = filter_tables(0, ranked)[:MAX_TABLE]

    # Enable thinking mode for questions rated harder than 7.
    dif = int(evaluate_difficulty(client, prompt))
    if dif > 7:
        print("Difficulty is > 7, enabling thinking mode")
        use_thinking = True

    sql = generate_sql(client, prompt, tables, use_thinking)

    # Run the query, asking the model to self-correct the SQL on failure.
    nb_try = 0
    success = False
    while nb_try < retry and not success:
        nb_try += 1
        try:
            print("Running the SQL query")
            result = database.query(sql)
            success = True
        except Exception as e:
            print(f"Error: {e}")
            print("Trying to self-correct...")
            error = f"{type(e).__name__} - {str(e)}"
            # The final attempts pass a different flag to correct_sql.
            if nb_try < retry - 2:
                sql = correct_sql(client, prompt, sql, tables, error, True)
            else:
                sql = correct_sql(client, prompt, sql, tables, error, False)

    print(sql)

    if success:
        print(result.to_markdown())
        if give_answer:
            return generate_answer(client, sql, prompt, result.to_markdown(), use_thinking)
        else:
            return f"Generated SQL query: {sql}\nQuery result:\n{result.to_markdown()}"

    # All attempts failed: report the last query instead of silently returning None.
    return f"Could not execute a valid SQL query after {retry} attempts. Last query: {sql}"
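

# --- Hypothetical usage sketch (illustration only, not part of the original logic) ---
# Assumes Database can be constructed from a connection string read from an assumed
# DATABASE_URL environment variable; adapt to the real constructor defined in database.py.
if __name__ == "__main__":
    db = Database(os.getenv("DATABASE_URL"))  # assumed constructor signature
    print(run_agent(db, "How many rows does each table contain?"))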