|
import logging
import os
from typing import Any, Dict, List
from urllib.parse import quote_plus

import pandas as pd
import pymysql
import Stemmer
from IPython.display import HTML, display
from llama_index.core import PromptTemplate, SQLDatabase, VectorStoreIndex
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.objects import (
    ObjectIndex,
    SQLTableNodeMapping,
    SQLTableSchema,
)
from llama_index.core.objects.base import ObjectRetriever
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.query_pipeline import (
    CustomQueryComponent,
    FnComponent,
    InputComponent,
    Link,
    QueryPipeline,
)
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import IndexNode
from llama_index.llms.ollama import Ollama
from llama_index.retrievers.bm25 import BM25Retriever
from pyvis.network import Network
from sqlalchemy import create_engine

from modules import (
    CustomSQLRetriever,
    create_table_obj_retriever,
    get_table_context_str,
    get_table_obj_retriever,
    parse_response_to_sql,
)
|
|
|
|
|
|
|
# Prompt template requesting a table summary in a fixed one-line format;
# ``{table_str}`` is its only placeholder.
PROMPT_STR = (
    "Give me a summary of the table with the following format.\n"
    "\n"
    "- table_summary: Describe what the table is about in short. "
    "Columns: [col1(type), col2(type), ...]\n"
    "\n"
    "Table:\n"
    "{table_str}\n"
)
|
|
|
# Connection settings for the MySQL database; SQLPipeline interpolates them
# into its SQLAlchemy engine URL.
# Every value can be overridden through a TEXT2SQL_DB_* environment variable
# so a deployment does not have to edit this file. The literals are only
# fallbacks that keep the previous behaviour when nothing is set.
# NOTE(review): a real password is committed below -- rotate it and drop the
# fallback once every environment provides TEXT2SQL_DB_PASSWORD.
db_user = os.environ.get("TEXT2SQL_DB_USER", "shenzhen_ai_for_vibemate_eson")
db_password = os.environ.get("TEXT2SQL_DB_PASSWORD", "dBsnc7OrM0MVi0FEhiHe2y")
db_host = os.environ.get("TEXT2SQL_DB_HOST", "192.168.1.99")
# Environment values are strings; keep the port an int as it was before.
db_port = int(os.environ.get("TEXT2SQL_DB_PORT", "3306"))
db_name = os.environ.get("TEXT2SQL_DB_NAME", "hytto_surfease")
|
|
|
# Hand-written schema summaries keyed by table name. The values are looked up
# when building TableInfo objects and are returned by TableRetrieveComponent
# as the schema text for the text-to-SQL prompt.
# The t_sur_models_info entry previously had unbalanced parentheses around the
# VARCHAR types and listed update_time twice; both are corrected here.
TABLE_SUMMARY = {
    "t_sur_media_sync_es": "This table is about Porn video information:\n\nt_sur_media_sync_es: Columns:id (integer), web_url (string), duration (integer), pattern_per (integer), like_count (integer), dislike_count (integer), view_count (integer), cover_picture (string), title (string), upload_date (datetime), uploader (string), create_time (datetime), update_time (datetime), categories (list of strings), abbreviate_video_url (string), abbreviate_mp4_video_url (string), resource_type (string), like_count_show (integer), stat_version (integer), tags (list of strings), model_name (string), publisher_type (string), period (integer), sexual_preference (string), country (string), type (string), rank_number (integer), rank_rate (float), has_pattern (boolean), trace (string), manifest_url (string), is_delete (boolean), web_url_md5 (string), view_key (string)",
    "t_sur_models_info": "This table is about Stripchat models' information:\n\nt_sur_models_info: Columns:id (INTEGER), username (VARCHAR(100)), image (VARCHAR(500)), num_users (INTEGER), pf (VARCHAR(50)), pf_model_unite (VARCHAR(50)), use_plugin (INTEGER), create_time (DATETIME), update_time (DATETIME), gender (VARCHAR(50)), broadcast_type (VARCHAR(50)), common_gender (VARCHAR(50)), avatar (VARCHAR(512)), age (INTEGER) ",
}
|
|
|
class SQLPipeline:
    """Text-to-SQL query pipeline over the configured MySQL database.

    Construction connects to the database described by the module-level
    ``db_*`` settings, builds a ``TableInfo`` for each usable table that has
    an entry in ``TABLE_SUMMARY``, prepares the pipeline modules and links
    them into a ``QueryPipeline``.
    """

    # Default location of the persisted table-object index (previously
    # hard-coded inside ``prepare_modules``).
    DEFAULT_TABLE_OBJ_INDEX_PATH = "/home/purui/projects/chatbot/kb/sql/table_obj_index"

    def __init__(
        self,
        llm: Ollama,
        table_obj_index_path: str = DEFAULT_TABLE_OBJ_INDEX_PATH,
    ):
        """Connect to the database and assemble the pipeline.

        Args:
            llm: Model used for both SQL generation and response synthesis.
            table_obj_index_path: Path handed to ``create_table_obj_retriever``
                as ``index_path``. Defaults to the path that was previously
                hard-coded, so existing callers are unaffected.
        """
        self.llm = llm
        self.table_obj_index_path = table_obj_index_path
        # URL-escape the credentials so a user name or password containing
        # reserved characters (``@``, ``:``, ``/`` ...) cannot corrupt the URL.
        self.engine = create_engine(
            f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}"
            f"@{db_host}:{db_port}/{db_name}"
        )
        self.sql_db = SQLDatabase(self.engine)
        self.table_names = self.sql_db.get_usable_table_names()
        self.schema_table_mapping = {}
        self.init_schema_table_mapping()
        self.modules = self.prepare_modules()
        self.pipeline = self.build_pipeline()

    def init_schema_table_mapping(self):
        """Populate ``self.table_infos`` and ``self.schema_table_mapping``.

        One ``TableInfo`` is created per usable table that has a summary in
        ``TABLE_SUMMARY``. Tables without a summary are skipped with a warning
        instead of aborting construction with a ``KeyError``.
        """
        self.table_infos = []
        for table in self.table_names:
            summary = TABLE_SUMMARY.get(table)
            if summary is None:
                logging.getLogger(__name__).warning(
                    "No summary configured for table %r; skipping it", table
                )
                continue

            table_info = TableInfo(table_name=table, table_summary=summary)
            self.table_infos.append(table_info)
            self.schema_table_mapping[table_info.table_name] = table

    def prepare_modules(self):
        """Build and return the named modules used by the query pipeline.

        Returns:
            Dict mapping module name to component: ``input``,
            ``table_retriever``, ``text2sql_prompt``, ``text2sql_llm``,
            ``sql_output_parser``, ``sql_retriever``,
            ``response_synthesis_prompt`` and ``response_synthesis_llm``.
        """
        modules = {}

        modules["input"] = InputComponent()

        retriever = create_table_obj_retriever(
            index_path=self.table_obj_index_path,
            table_infos=self.table_infos,
            sql_db=self.sql_db,
            schema_table_mapping=self.schema_table_mapping,
        )
        modules["table_retriever"] = TableRetrieveComponent(
            retriever=retriever,
            sql_database=self.sql_db,
        )

        # Pre-fill the SQL dialect so only the query and schema stay open.
        text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
            dialect=self.engine.dialect.name
        )
        modules["text2sql_prompt"] = text2sql_prompt

        modules["text2sql_llm"] = self.llm

        modules["sql_output_parser"] = FnComponent(fn=parse_response_to_sql)

        modules["sql_retriever"] = CustomSQLRetriever(sql_db=self.sql_db)

        response_synthesis_prompt_str = (
            "Given an input question, synthesize a response from the query results.\n"
            "Query: {query_str}\n"
            "SQL: {sql_query}\n"
            "SQL Response: {context_str}\n"
            "Response: "
        )
        response_synthesis_prompt = PromptTemplate(
            response_synthesis_prompt_str,
        )
        modules["response_synthesis_prompt"] = response_synthesis_prompt

        # The same LLM instance serves both generation steps.
        modules["response_synthesis_llm"] = self.llm

        return modules

    def build_pipeline(self):
        """Wire ``self.modules`` into a ``QueryPipeline`` and return it."""
        qp = QueryPipeline(
            modules=self.modules,
            verbose=True,
        )

        # User query -> table retrieval and text-to-SQL prompt.
        qp.add_link("input", "table_retriever", dest_key="query")
        qp.add_link("input", "text2sql_prompt", dest_key="query_str")
        qp.add_link("table_retriever", "text2sql_prompt", dest_key="schema")
        qp.add_chain(
            ["text2sql_prompt", "text2sql_llm", "sql_output_parser"]
        )
        qp.add_link(
            "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
        )
        qp.add_link("input", "sql_retriever", dest_key="query_str")
        qp.add_link("sql_output_parser", "sql_retriever", dest_key="sql_query")

        # Conditional routing keyed on ``is_valid`` in the sql_retriever
        # output: results go on to synthesis, otherwise the query is fed back
        # to the text-to-SQL prompt.
        # NOTE(review): assumes CustomSQLRetriever emits an ``is_valid`` flag
        # and a ``query_str`` output -- confirm in ``modules``.
        qp.add_link(
            "sql_retriever",
            "response_synthesis_prompt",
            dest_key="context_str",
            condition_fn=lambda x: x["is_valid"],
        )
        qp.add_link(
            "sql_retriever",
            "text2sql_prompt",
            src_key="query_str",
            dest_key="query_str",
            condition_fn=lambda x: not x["is_valid"],
        )
        qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
        qp.add_link("response_synthesis_prompt", "response_synthesis_llm")
        return qp

    def get_vision(self):
        """Render the pipeline graph to ``text2sql_dag.html`` and display it."""
        net = Network(notebook=True, cdn_resources="in_line", directed=True)
        net.from_nx(self.pipeline.dag)
        net.write_html("text2sql_dag.html")
        with open("text2sql_dag.html", "r") as file:
            html_content = file.read()

        display(HTML(html_content))

    def run(self, query: str):
        """Run the pipeline for ``query`` and return the response as a string."""
        response = self.pipeline.run(query=query)
        return str(response)
|
|
|
|
|
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # Required: name of the table this record describes.
    table_name: str = Field(
        default=...,
        description="table name (must be underscores and NO spaces)",
    )
    # Required: caption text for the table.
    table_summary: str = Field(
        default=...,
        description="short, concise summary/caption of the table",
    )
|
|
|
|
|
class TableRetrieveComponent(CustomQueryComponent):
    """Retrieves table information from the database."""

    retriever: ObjectRetriever = Field(..., description="Retriever to use for table info")
    sql_database: SQLDatabase = Field(..., description="SQL engine to use for table info")

    def _validate_component_inputs(
        self, input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        No validation is performed; the inputs are returned unchanged.
        """
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        return {"query"}

    @property
    def _output_keys(self) -> set:
        """Output keys dict."""
        return {"output"}

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Run the component.

        Retrieves table schemas for ``kwargs["query"]`` and returns the
        description of the first one under the ``"output"`` key.

        Raises:
            IndexError: If the retriever returns no table for the query.
        """
        retrieved = self.retriever.retrieve(kwargs["query"])
        if not retrieved:
            # Same exception type as the bare ``[0]`` lookup raised before,
            # but with a message that says what actually went wrong.
            raise IndexError(
                f"No table schema retrieved for query: {kwargs['query']!r}"
            )

        table_name = retrieved[0].table_name
        table_info = TABLE_SUMMARY.get(table_name)
        if table_info is None:
            # No hand-written summary: fall back to the schema text reported
            # by the database instead of failing with a KeyError.
            table_info = self.sql_database.get_single_table_info(table_name)
        return {"output": table_info}
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: build the pipeline on a local Ollama model, send one query
    # and print the synthesized answer.
    demo_llm = Ollama(
        model="mannix/llama3.1-8b-abliterated",
        request_timeout=120,
    )
    sql_pipeline = SQLPipeline(llm=demo_llm)
    response = sql_pipeline.run("I want 5 different big tits milf porn with it's title and web url")
    print(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|