File size: 5,664 Bytes
679f269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import chainlit as cl
import llama_index
from llama_index import set_global_handler
from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext
from llama_index.llms import OpenAI
from llama_index import SimpleDirectoryReader
from llama_index.ingestion import IngestionPipeline
from llama_index.node_parser import TokenTextSplitter
from llama_index import load_index_from_storage
from llama_index.tools import FunctionTool
from llama_index.vector_stores.types import (
    VectorStoreInfo,
    MetadataInfo,
    ExactMatchFilter,
    MetadataFilters,
)
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from typing import List
from AutoRetrieveModel import AutoRetrieveModel
from llama_index.agent import OpenAIAgent
from sqlalchemy import create_engine
from llama_index import SQLDatabase
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index.tools.query_engine import QueryEngineTool
import pandas as pd
import openai

# Module-level side effect: route all llama_index tracing to Weights & Biases
# under the given project, then grab the installed handler so we can use its
# artifact helpers (load_storage_context) below.
set_global_handler("wandb", run_args={"project": "llamaindex-demo-v1"})
wandb_callback = llama_index.global_handler

def create_semantic_agent(service_context):
    """Build a FunctionTool that performs filtered semantic retrieval over the wiki index.

    Args:
        service_context: llama_index ServiceContext supplying the LLM and
            embedding model used for retrieval and response synthesis.

    Returns:
        FunctionTool: an auto-retrieval tool the agent can call with a query
        plus parallel lists of metadata filter keys/values.
    """
    # Load the previously-persisted wikipedia index from the W&B artifact store.
    storage_context = wandb_callback.load_storage_context(
        artifact_url="jfreeman/llamaindex-demo-v1/wiki-index:v1"
    )

    index = load_index_from_storage(storage_context, service_context=service_context)

    def auto_retrieve_fn(
        query: str, filter_key_list: List[str], filter_value_list: List[str]
    ) -> str:
        """Auto retrieval function.

        Performs auto-retrieval from a vector database, and then applies a set of filters.

        """
        # Guard against the LLM passing an empty query string.
        query = query or "Query"

        exact_match_filters = [
            ExactMatchFilter(key=k, value=v)
            for k, v in zip(filter_key_list, filter_value_list)
        ]
        # BUG FIX: the keyword is `similarity_top_k`; the original passed
        # `top_k`, which VectorIndexRetriever swallows via **kwargs, so the
        # intended top-3 limit was never applied.
        retriever = VectorIndexRetriever(
            index,
            filters=MetadataFilters(filters=exact_match_filters),
            similarity_top_k=3,
        )
        query_engine = RetrieverQueryEngine.from_args(
            retriever, service_context=service_context
        )

        response = query_engine.query(query)
        return str(response)

    # Schema advertised to the LLM so it can choose valid filter keys/values.
    vector_store_info = VectorStoreInfo(
        content_info="semantic information about movies",
        metadata_info=[MetadataInfo(
            name="title",
            type="str",
            description="title of the movie, one of [John Wick (film), John Wick: Chapter 2, John Wick: Chapter 3 – Parabellum, John Wick: Chapter 4]",
        )]
    )
    description = f"""\
    Use this tool to look up semantic information about films.
    The vector database schema is given below:
    {vector_store_info.json()}
    """
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="semantic-film-info",
        description=description,
        fn_schema=AutoRetrieveModel
    )
    return auto_retrieve_tool

def create_sql_agent(service_context):
    """Build a QueryEngineTool that answers natural-language questions via SQL.

    Loads the four John Wick review CSVs into an in-memory SQLite database and
    wraps them in an NL-to-SQL query engine.

    Args:
        service_context: llama_index ServiceContext supplying the LLM used for
            text-to-SQL translation.

    Returns:
        QueryEngineTool: a tool the agent can route tabular/review questions to.
    """
    engine = create_engine("sqlite+pysqlite:///:memory:")

    # Single source of truth for the table names (previously duplicated in
    # three hard-coded lists that had to be kept in sync by hand).
    table_names = [f"John Wick {i}" for i in range(1, 5)]

    for i, table_name in enumerate(table_names, start=1):
        # CSVs are expected at wick_tables/jw1.csv .. wick_tables/jw4.csv.
        fn = os.path.join('wick_tables', f'jw{i}.csv')
        df = pd.read_csv(fn)
        df.to_sql(table_name, engine)

    sql_database = SQLDatabase(
        engine=engine,
        include_tables=table_names,
    )

    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=table_names,
        service_context=service_context,
    )

    # FIX: corrected typos in the LLM-facing description ("natrual" ->
    # "natural", "call" -> "called", "forth" -> "fourth") — the agent reads
    # this text when deciding whether to route a query here.
    sql_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name="sql-query",
        description=(
            "Useful for translating a natural language query into a SQL query over a table containing: "
            "John Wick 1, containing information related to reviews of the first John Wick movie called 'John Wick'"
            "John Wick 2, containing information related to reviews of the second John Wick movie called 'John Wick: Chapter 2'"
            "John Wick 3, containing information related to reviews of the third John Wick movie called 'John Wick: Chapter 3 - Parabellum'"
            "John Wick 4, containing information related to reviews of the fourth John Wick movie called 'John Wick: Chapter 4'"
        ),
    )
    return sql_tool
    
welcome_message = "Welcome to the John Wick RAQA demo! Ask me anything about the John Wick movies."
@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():
    """Initialize a user session: build the service context, both tools, and the agent.

    Side effects:
        Stores the constructed OpenAIAgent in the chainlit user session under
        the key "agent" and sends the welcome message to the user.
    """
    # Create the service context shared by both tools.
    embed_model = OpenAIEmbedding()
    chunk_size = 500
    llm = OpenAI(
        temperature=0,
        model='gpt-4-1106-preview',
        streaming=True
    )
    service_context = ServiceContext.from_defaults(
        llm=llm,
        chunk_size=chunk_size,
        embed_model=embed_model,
    )

    auto_retrieve_tool = create_semantic_agent(service_context)
    sql_tool = create_sql_agent(service_context)

    # CLEANUP: removed a commented-out duplicate agent construction that
    # differed only in tool order; the live ordering below is kept.
    agent = OpenAIAgent.from_tools(
        tools=[sql_tool, auto_retrieve_tool],
    )
    cl.user_session.set("agent", agent)
    await cl.Message(content=welcome_message).send()

@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    """Handle one incoming user message by delegating to the session's agent."""
    # The agent was stashed in the session by start_chat.
    session_agent = cl.user_session.get("agent")
    agent_response = await session_agent.achat(message.content)
    await cl.Message(content=str(agent_response)).send()