File size: 4,826 Bytes
24371db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.routers import ConditionalRouter
from functions import SQLiteQuery
from typing import List
import sqlite3
import os
from getpass import getpass
from dotenv import load_dotenv
# Populate os.environ from a local .env file (when one exists), then make sure
# an OpenAI key is present: the OpenAI generators created further down read it
# from the environment. Prompt for it interactively only as a last resort.
load_dotenv()
if os.environ.get("OPENAI_API_KEY") is None:
    os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
# An earlier, unconditional version of the SQL pipeline (prompt -> llm -> SQL
# querier, with no fallback route) used to be kept here inside a triple-quoted
# string. It never executed, and it referenced names that do not exist at this
# point in the module (`SQLQuery` is never imported; `columns` is only defined
# further down), so it was removed. Recover it from version control if needed.
from haystack.components.builders import PromptBuilder
from haystack.components.generators import OpenAIGenerator
# SQLite database that both the SQL querier component and the column
# discovery below read from.
DB_PATH = 'data_source.db'

llm = OpenAIGenerator(model="gpt-4o")
sql_query = SQLiteQuery(DB_PATH)

# Discover the column names of the `data_source` table so they can be rendered
# into the prompts. `LIMIT 0` is enough for sqlite3 to populate
# `cursor.description` without reading any rows. The connection is closed
# afterwards; previously it stayed open for the life of the process.
connection = sqlite3.connect(DB_PATH)
try:
    cur = connection.execute('select * from data_source limit 0')
    # Each description entry is a 7-tuple whose first item is the column name.
    columns = [description[0] for description in cur.description]
    cur.close()
finally:
    connection.close()
print("COLUMNS 2")
print(columns)
# Text-to-SQL prompt for the conditional pipeline. The literal 'no_answer'
# sentinel is what the router's conditions look for in the LLM reply. The
# reply is passed straight to the SQLite querier, so the model is told to
# return bare SQL: markdown fences or prose around the query would break it.
prompt = PromptBuilder(template="""Please generate an SQL query that answers the following question: {{question}}
The query must run against a SQLite table called 'data_source' with the following columns: {{columns}}
Return only the raw SQL query, with no explanation and no markdown code fences.
If the question cannot be answered given the provided table and columns, return 'no_answer' instead.
Answer:""")
# ConditionalRouter routes, checked in order against the SQL-generation LLM's
# replies:
#   * "sql": the first reply does not contain the 'no_answer' sentinel, so all
#     replies are forwarded as queries for the SQL querier.
#   * "go_to_fallback": the sentinel is present, so the original question is
#     forwarded to the fallback prompt instead.
routes = [
    dict(
        condition="{{'no_answer' not in replies[0]}}",
        output="{{replies}}",
        output_name="sql",
        output_type=List[str],
    ),
    dict(
        condition="{{'no_answer' in replies[0]}}",
        output="{{question}}",
        output_name="go_to_fallback",
        output_type=str,
    ),
]
# Components used only by the conditional pipeline: the router that inspects
# the SQL-generation reply, and the prompt/LLM pair that explains to the user
# why a question cannot be answered from the table.
router = ConditionalRouter(routes)
fallback_prompt = PromptBuilder(template="""User entered a query that cannot be answered with the given table.
The query was: {{question}} and the table had columns: {{columns}}.
Let the user know why the question cannot be answered""")
fallback_llm = OpenAIGenerator(model="gpt-4")

# Graph wiring:
#   prompt -> llm -> router --sql-------------> sql_querier
#                          \--go_to_fallback--> fallback_prompt -> fallback_llm
conditional_sql_pipeline = Pipeline()
for component_name, component in (
    ("prompt", prompt),
    ("llm", llm),
    ("router", router),
    ("fallback_prompt", fallback_prompt),
    ("fallback_llm", fallback_llm),
    ("sql_querier", sql_query),
):
    conditional_sql_pipeline.add_component(component_name, component)
for sender, receiver in (
    ("prompt", "llm"),
    ("llm.replies", "router.replies"),
    ("router.sql", "sql_querier.queries"),
    ("router.go_to_fallback", "fallback_prompt.question"),
    ("fallback_prompt", "fallback_llm"),
):
    conditional_sql_pipeline.connect(sender, receiver)
# Demo query that the table presumably cannot answer, so it is expected to take
# the fallback route. NOTE(review): this runs on every import of the module and
# triggers OpenAI API calls; consider moving it under
# `if __name__ == "__main__":`.
question = "When is my birthday?"
result = conditional_sql_pipeline.run(
    dict(
        prompt=dict(question=question, columns=columns),
        router=dict(question=question),
        fallback_prompt=dict(columns=columns),
    )
)
def rag_pipeline_func(question: str, columns: str):
    """Answer ``question`` with the conditional text-to-SQL pipeline.

    Runs ``conditional_sql_pipeline`` and picks the reply from whichever
    branch produced output: the SQL querier's first result when a query was
    generated, otherwise the fallback LLM's explanation.

    Args:
        question: Natural-language question about the ``data_source`` table.
        columns: Column names of ``data_source``, rendered into both prompts.
            The module-level ``columns`` (a ``list`` of names) works here too.

    Returns:
        ``{"reply": reply}`` where ``reply`` is the selected output. It is
        unwrapped from ``.content`` when the output is a message object;
        plain strings are returned as-is.
    """
    result = conditional_sql_pipeline.run({"prompt": {"question": question,
                                                      "columns": columns},
                                           "router": {"question": question},
                                           "fallback_prompt": {"columns": columns}})
    if 'sql_querier' in result:
        reply = result['sql_querier']['results'][0]
    elif 'fallback_llm' in result:
        reply = result['fallback_llm']['replies'][0]
    else:
        # NOTE(review): "llm" feeds the router, so Pipeline.run presumably does
        # not surface its output here and this lookup may raise KeyError.
        reply = result["llm"]["replies"][0]
    # OpenAIGenerator's ``replies`` are plain ``str`` values, which have no
    # ``.content`` attribute (the old ``reply.content`` raised AttributeError on
    # the fallback path). Unwrap only when the reply is a message-like object.
    content = getattr(reply, "content", reply)
    print("reply content")
    print(content)
    return {"reply": content}