# Source: SexBot/pipeline/sql_pipeline.py
# Uploaded by Pew404 using huggingface_hub (commit 318db6e, verified).
import os
from typing import Dict, List, Any

from llama_index.core.query_pipeline import (
    QueryPipeline,
    Link,
    InputComponent,
    CustomQueryComponent,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from pyvis.network import Network
import Stemmer
from IPython.display import display, HTML
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
from llama_index.core import SQLDatabase, VectorStoreIndex, PromptTemplate
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.query_pipeline import FnComponent
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.retrievers import SQLRetriever
from llama_index.llms.ollama import Ollama
from llama_index.core.objects.base import ObjectRetriever
import pymysql
import pandas as pd
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.schema import IndexNode

from modules import (
    get_table_obj_retriever,
    create_table_obj_retriever,
    get_table_context_str,
    parse_response_to_sql,
    CustomSQLRetriever,
)
# Prompt template asking an LLM for a one-line table summary; ``{table_str}``
# is the placeholder for the table text. It is not referenced elsewhere in the
# visible code of this module.
PROMPT_STR = (
    "Give me a summary of the table with the following format.\n"
    "- table_summary: Describe what the table is about in short. "
    "Columns: [col1(type), col2(type), ...]\n"
    "Table:\n"
    "{table_str}\n"
)
# MySQL connection settings. Each value can be overridden with an environment
# variable; the hard-coded fallbacks preserve the previous behaviour.
# NOTE(security): a plaintext password is committed below -- rotate it and drop
# the fallback once every deployment sets SQL_PIPELINE_DB_PASSWORD.
db_user = os.environ.get("SQL_PIPELINE_DB_USER", "shenzhen_ai_for_vibemate_eson")
db_password = os.environ.get("SQL_PIPELINE_DB_PASSWORD", "dBsnc7OrM0MVi0FEhiHe2y")
db_host = os.environ.get("SQL_PIPELINE_DB_HOST", "192.168.1.99")
# Environment values are strings; convert so the port stays an int as before.
db_port = int(os.environ.get("SQL_PIPELINE_DB_PORT", "3306"))
db_name = os.environ.get("SQL_PIPELINE_DB_NAME", "hytto_surfease")
# Hand-written description of each queryable table, keyed by the real table
# name. The text is returned by the table retriever and used as the schema in
# the text-to-SQL prompt, so it must list each column exactly once with a
# well-formed type.
TABLE_SUMMARY = {
    "t_sur_media_sync_es": (
        "This table is about Porn video information:\n\n"
        "t_sur_media_sync_es: Columns: id (integer), web_url (string), "
        "duration (integer), pattern_per (integer), like_count (integer), "
        "dislike_count (integer), view_count (integer), cover_picture (string), "
        "title (string), upload_date (datetime), uploader (string), "
        "create_time (datetime), update_time (datetime), "
        "categories (list of strings), abbreviate_video_url (string), "
        "abbreviate_mp4_video_url (string), resource_type (string), "
        "like_count_show (integer), stat_version (integer), "
        "tags (list of strings), model_name (string), publisher_type (string), "
        "period (integer), sexual_preference (string), country (string), "
        "type (string), rank_number (integer), rank_rate (float), "
        "has_pattern (boolean), trace (string), manifest_url (string), "
        "is_delete (boolean), web_url_md5 (string), view_key (string)"
    ),
    # Fixed: every VARCHAR(n) now has its closing parenthesis, and the
    # duplicated update_time column has been removed.
    "t_sur_models_info": (
        "This table is about Stripchat models' information:\n\n"
        "t_sur_models_info: Columns: id (INTEGER), username (VARCHAR(100)), "
        "image (VARCHAR(500)), num_users (INTEGER), pf (VARCHAR(50)), "
        "pf_model_unite (VARCHAR(50)), use_plugin (INTEGER), "
        "create_time (DATETIME), update_time (DATETIME), gender (VARCHAR(50)), "
        "broadcast_type (VARCHAR(50)), common_gender (VARCHAR(50)), "
        "avatar (VARCHAR(512)), age (INTEGER)"
    ),
}
class SQLPipeline:
    """Text-to-SQL question answering pipeline over the configured MySQL database.

    For a question, the pipeline retrieves the most relevant table description,
    has the LLM write SQL for it, executes that SQL, and has the LLM phrase the
    SQL result as an answer. The execution step is also wired to route back to
    the text-to-SQL prompt when its result is flagged as invalid.
    """

    # Default location passed to ``create_table_obj_retriever`` as ``index_path``.
    DEFAULT_TABLE_OBJ_INDEX_PATH = "/home/purui/projects/chatbot/kb/sql/table_obj_index"

    def __init__(self, llm: Ollama, table_obj_index_path: str = DEFAULT_TABLE_OBJ_INDEX_PATH):
        """Connect to the database and build the query pipeline.

        Args:
            llm: LLM used both to write SQL and to synthesize the final answer.
            table_obj_index_path: index location handed to
                ``create_table_obj_retriever``; defaults to the previously
                hard-coded path.
        """
        self.llm = llm
        self.table_obj_index_path = table_obj_index_path
        # URL.create escapes special characters in the credentials; an f-string
        # URL breaks as soon as the password contains '@', '/', ':' etc.
        url = URL.create(
            drivername="mysql+pymysql",
            username=db_user,
            password=db_password,
            host=db_host,
            port=db_port,
            database=db_name,
        )
        self.engine = create_engine(url)
        self.sql_db = SQLDatabase(self.engine)
        self.table_names = self.sql_db.get_usable_table_names()
        self.schema_table_mapping = {}
        self.init_schema_table_mapping()
        self.modules = self.prepare_modules()
        self.pipeline = self.build_pipeline()

    def init_schema_table_mapping(self):
        """Populate ``self.table_infos`` and ``self.schema_table_mapping``.

        One ``TableInfo`` is built for every usable table that has an entry in
        ``TABLE_SUMMARY``. Tables without a summary are skipped, instead of
        aborting construction with a ``KeyError``.
        """
        self.table_infos = []
        for table in self.table_names:
            summary = TABLE_SUMMARY.get(table)
            if summary is None:
                continue
            table_info = TableInfo(table_name=table, table_summary=summary)
            self.table_infos.append(table_info)
            # Summary (indexed) table name -> real table name; the two are
            # currently identical.
            self.schema_table_mapping[table_info.table_name] = table

    def prepare_modules(self):
        """Create the named query-pipeline modules that ``build_pipeline`` links."""
        modules = {}
        # input
        modules["input"] = InputComponent()
        # table retriever
        retriever = create_table_obj_retriever(
            index_path=self.table_obj_index_path,
            table_infos=self.table_infos,
            sql_db=self.sql_db,
            schema_table_mapping=self.schema_table_mapping
        )
        modules["table_retriever"] = TableRetrieveComponent(
            retriever=retriever,
            sql_database=self.sql_db
        )
        # text2sql_prompt: the default text-to-SQL prompt with the SQL dialect
        # of the connected engine filled in.
        text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
            dialect=self.engine.dialect.name
        )
        modules["text2sql_prompt"] = text2sql_prompt
        # text2sql_llm
        modules["text2sql_llm"] = self.llm
        # sql output parser
        modules["sql_output_parser"] = FnComponent(fn=parse_response_to_sql)
        # sql retriever (project-specific replacement for llama_index's SQLRetriever)
        modules["sql_retriever"] = CustomSQLRetriever(sql_db=self.sql_db)
        # response synthesis prompt
        response_synthesis_prompt_str = (
            "Given an input question, synthesize a response from the query results.\n"
            "Query: {query_str}\n"
            "SQL: {sql_query}\n"
            "SQL Response: {context_str}\n"
            "Response: "
        )
        response_synthesis_prompt = PromptTemplate(
            response_synthesis_prompt_str,
        )
        modules["response_synthesis_prompt"] = response_synthesis_prompt
        # response synthesis llm
        modules["response_synthesis_llm"] = self.llm
        return modules

    def build_pipeline(self):
        """Wire the modules from ``self.modules`` into a ``QueryPipeline`` and return it."""
        qp = QueryPipeline(
            modules=self.modules,
            verbose=True,
        )
        # add chains & links
        qp.add_link("input", "table_retriever", dest_key="query")
        qp.add_link("input", "text2sql_prompt", dest_key="query_str")
        qp.add_link("table_retriever", "text2sql_prompt", dest_key="schema")
        qp.add_chain(
            ["text2sql_prompt", "text2sql_llm", "sql_output_parser"]
        )
        qp.add_link(
            "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
        )
        qp.add_link("input", "sql_retriever", dest_key="query_str")
        qp.add_link("sql_output_parser", "sql_retriever", dest_key="sql_query")
        # The custom sql_retriever component defines an ``is_valid`` field: it is
        # True when executing the SQL returned a correct result, and serves as the
        # condition for the sql_retriever -> response_synthesis_prompt link.
        # When is_valid is False, the flow goes back into text2sql_prompt so the
        # SQL is regenerated.
        qp.add_link(
            "sql_retriever", "response_synthesis_prompt", dest_key="context_str",
            condition_fn=lambda x: x["is_valid"],
        )
        qp.add_link(
            "sql_retriever", "text2sql_prompt", src_key="query_str", dest_key="query_str",
            condition_fn=lambda x: not x["is_valid"],
        )
        qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
        qp.add_link("response_synthesis_prompt", "response_synthesis_llm")
        return qp

    def get_vision(self):
        """Render the pipeline DAG with pyvis and display it inline in IPython.

        Side effect: writes ``text2sql_dag.html`` to the current working directory.
        """
        net = Network(notebook=True, cdn_resources="in_line", directed=True)
        net.from_nx(self.pipeline.dag)
        net.write_html("text2sql_dag.html")
        with open("text2sql_dag.html", "r") as file:
            html_content = file.read()
        # Display the HTML content
        display(HTML(html_content))

    def run(self, query: str):
        """Run the pipeline on ``query`` and return the final response as a string."""
        response = self.pipeline.run(query=query)
        return str(response)
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # Real table name. The description below asks for underscores and no
    # spaces, but no validator enforces it.
    table_name: str = Field(
        default=...,
        description="table name (must be underscores and NO spaces)",
    )
    # Short human-readable caption of the table's contents.
    table_summary: str = Field(
        default=...,
        description="short, concise summary/caption of the table",
    )
class TableRetrieveComponent(CustomQueryComponent):
    """Query-pipeline component that picks the table description for a query.

    The ``query`` input is passed to ``retriever``; the description of the
    top-ranked table is returned under the ``output`` key. ``SQLPipeline`` links
    that output into the ``schema`` slot of the text-to-SQL prompt.
    """

    retriever: ObjectRetriever = Field(..., description="Retriever to use for table info")
    sql_database: SQLDatabase = Field(..., description="SQL engine to use for table info")

    def _validate_component_inputs(
        self, input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate component inputs during run_component (currently a pass-through)."""
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys consumed by this component."""
        return {"query"}

    @property
    def _output_keys(self) -> set:
        """Output keys produced by this component."""
        return {"output"}

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Return the description of the table most relevant to ``kwargs["query"]``.

        Returns:
            ``{"output": <table description str>}``.

        Raises:
            ValueError: if the retriever returns no table for the query.
        """
        query = kwargs["query"]
        retrieved = self.retriever.retrieve(query)
        if not retrieved:
            raise ValueError(f"No table matched the query: {query!r}")
        table_name = retrieved[0].table_name
        # Prefer the hand-written summary. For tables without one, fall back to
        # the schema text the SQLDatabase builds for that table instead of failing
        # with KeyError.
        table_info = TABLE_SUMMARY.get(table_name)
        if table_info is None:
            table_info = self.sql_database.get_single_table_info(table_name)
        return {"output": table_info}
if __name__ == "__main__":
    # Manual smoke run: build the pipeline with an Ollama model and answer a
    # single sample question.
    demo_llm = Ollama(
        model="mannix/llama3.1-8b-abliterated",
        request_timeout=120,
    )
    sql_pipeline = SQLPipeline(llm=demo_llm)
    answer = sql_pipeline.run("I want 5 different big tits milf porn with it's title and web url")
    print(answer)