"""Chainlit chat app that answers questions about the John Wick films.

Combines two LlamaIndex tools behind a single OpenAI agent:
  * a semantic auto-retrieval tool over a wiki vector index restored from a
    Weights & Biases artifact, and
  * a text-to-SQL tool over per-movie review tables loaded from local CSVs
    into an in-memory SQLite database.
"""

import os
from typing import List

import chainlit as cl
import llama_index
import openai
import pandas as pd
from sqlalchemy import create_engine

from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
    SQLDatabase,
    load_index_from_storage,
    set_global_handler,
)
from llama_index.agent import OpenAIAgent
from llama_index.embeddings import OpenAIEmbedding
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index.ingestion import IngestionPipeline
from llama_index.llms import OpenAI
from llama_index.node_parser import TokenTextSplitter
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import VectorIndexRetriever
from llama_index.tools import FunctionTool
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.vector_stores.types import (
    ExactMatchFilter,
    MetadataFilters,
    MetadataInfo,
    VectorStoreInfo,
)

from AutoRetrieveModel import AutoRetrieveModel

# Route all LlamaIndex tracing through Weights & Biases; the global handler
# doubles as the artifact loader used in create_semantic_agent().
set_global_handler("wandb", run_args={"project": "llamaindex-demo-v1"})
wandb_callback = llama_index.global_handler

# The review tables share one canonical name list so the SQLDatabase
# `include_tables` and the NLSQLTableQueryEngine `tables` cannot drift apart.
MOVIE_TABLE_NAMES = ["John Wick 1", "John Wick 2", "John Wick 3", "John Wick 4"]

welcome_message = "Welcome to the John Wick RAQA demo! Ask me anything about the John Wick movies."


def create_semantic_agent(service_context):
    """Build the semantic-retrieval tool over the wiki vector index.

    Loads the persisted index from the W&B artifact, then wraps an
    auto-retrieval function (query + metadata filters) as a FunctionTool
    whose argument schema is AutoRetrieveModel.

    Args:
        service_context: ServiceContext supplying the LLM and embed model.

    Returns:
        FunctionTool named "semantic-film-info".
    """
    # Load in wikipedia index from the W&B artifact store.
    storage_context = wandb_callback.load_storage_context(
        artifact_url="jfreeman/llamaindex-demo-v1/wiki-index:v1"
    )
    index = load_index_from_storage(storage_context, service_context=service_context)

    def auto_retrieve_fn(
        query: str, filter_key_list: List[str], filter_value_list: List[str]
    ):
        """Auto retrieval function.

        Performs auto-retrieval from a vector database, and then applies a set of filters.
        """
        # Guard against the agent passing an empty query string.
        query = query or "Query"

        exact_match_filters = [
            ExactMatchFilter(key=k, value=v)
            for k, v in zip(filter_key_list, filter_value_list)
        ]
        # NOTE: the documented keyword is `similarity_top_k`; the original
        # passed `top_k`, which VectorIndexRetriever silently ignores.
        retriever = VectorIndexRetriever(
            index,
            filters=MetadataFilters(filters=exact_match_filters),
            similarity_top_k=3,
        )
        query_engine = RetrieverQueryEngine.from_args(
            retriever, service_context=service_context
        )
        response = query_engine.query(query)
        return str(response)

    vector_store_info = VectorStoreInfo(
        content_info="semantic information about movies",
        metadata_info=[
            MetadataInfo(
                name="title",
                type="str",
                description="title of the movie, one of [John Wick (film), John Wick: Chapter 2, John Wick: Chapter 3 – Parabellum, John Wick: Chapter 4]",
            )
        ],
    )
    # The schema is embedded in the tool description so the agent LLM knows
    # which metadata filters it may construct.
    description = f"""\
Use this tool to look up semantic information about films.
The vector database schema is given below:
{vector_store_info.json()}
"""
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="semantic-film-info",
        description=description,
        fn_schema=AutoRetrieveModel,
    )
    return auto_retrieve_tool


def create_sql_agent(service_context):
    """Build the text-to-SQL tool over the John Wick review tables.

    Reads wick_tables/jw1.csv .. jw4.csv into an in-memory SQLite database
    (one table per film) and wraps an NLSQLTableQueryEngine as a
    QueryEngineTool.

    Args:
        service_context: ServiceContext supplying the LLM and embed model.

    Returns:
        QueryEngineTool named "sql-query".
    """
    engine = create_engine("sqlite+pysqlite:///:memory:")
    for i in range(1, 5):
        fn = os.path.join("wick_tables", f"jw{i}.csv")
        df = pd.read_csv(fn)
        df.to_sql(f"John Wick {i}", engine)

    sql_database = SQLDatabase(engine=engine, include_tables=MOVIE_TABLE_NAMES)
    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=MOVIE_TABLE_NAMES,
        service_context=service_context,
    )
    # Original description concatenated its entries with no separators and
    # contained typos ("natrual", "call", "forth") — fixed here since the
    # agent LLM reads this text verbatim when routing queries.
    sql_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name="sql-query",
        description=(
            "Useful for translating a natural language query into a SQL query "
            "over a table containing: "
            "John Wick 1, containing information related to reviews of the first John Wick movie called 'John Wick'; "
            "John Wick 2, containing information related to reviews of the second John Wick movie called 'John Wick: Chapter 2'; "
            "John Wick 3, containing information related to reviews of the third John Wick movie called 'John Wick: Chapter 3 - Parabellum'; "
            "John Wick 4, containing information related to reviews of the fourth John Wick movie called 'John Wick: Chapter 4'"
        ),
    )
    return sql_tool


@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():
    """Initialize the per-session agent and greet the user."""
    # Create the service context shared by both tools.
    embed_model = OpenAIEmbedding()
    chunk_size = 500
    llm = OpenAI(
        temperature=0,
        model="gpt-4-1106-preview",
        streaming=True,
    )
    service_context = ServiceContext.from_defaults(
        llm=llm,
        chunk_size=chunk_size,
        embed_model=embed_model,
    )

    auto_retrieve_tool = create_semantic_agent(service_context)
    sql_tool = create_sql_agent(service_context)

    agent = OpenAIAgent.from_tools(
        tools=[sql_tool, auto_retrieve_tool],
    )

    # Stash the agent in the user session so each chat has its own instance.
    cl.user_session.set("agent", agent)
    await cl.Message(content=welcome_message).send()


@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    """Forward the user's message to the session agent and send its reply."""
    agent = cl.user_session.get("agent")
    res = await agent.achat(message.content)
    answer = str(res)
    await cl.Message(content=answer).send()