raul-padua commited on
Commit
a2c1b0b
1 Parent(s): 9227993

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -99
app.py CHANGED
@@ -1,110 +1,165 @@
1
- import chainlit as cl
2
- from langchain.embeddings.openai import OpenAIEmbeddings
3
- from langchain.document_loaders.csv_loader import CSVLoader
4
- from langchain.embeddings import CacheBackedEmbeddings
5
- from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain.vectorstores import FAISS
7
- from langchain.chains import RetrievalQA
8
- from langchain.chat_models import ChatOpenAI
9
- from langchain.storage import LocalFileStore
10
- from langchain.prompts.chat import (
11
- ChatPromptTemplate,
12
- SystemMessagePromptTemplate,
13
- HumanMessagePromptTemplate,
 
 
 
 
 
 
 
 
 
 
 
 
14
  )
15
- import chainlit as cl
16
-
17
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
18
-
19
- system_template = """
20
- Use the following pieces of context to answer the user's question.
21
- Please respond as if you were Ken from the movie Barbie. Ken is a well-meaning but naive character who loves to Beach. He talks like a typical Californian Beach Bro, but he doesn't use the word "Dude" so much.
22
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
23
- You can make inferences based on the context as long as it still faithfully represents the feedback.
24
-
25
- Example of your response should be:
26
-
27
- ```
28
- The answer is foo
29
- ```
30
-
31
- Begin!
32
- ----------------
33
- {context}"""
34
-
35
- messages = [
36
- SystemMessagePromptTemplate.from_template(system_template),
37
- HumanMessagePromptTemplate.from_template("{question}"),
38
- ]
39
- prompt = ChatPromptTemplate(messages=messages)
40
- chain_type_kwargs = {"prompt": prompt}
41
-
42
- @cl.author_rename
43
- def rename(orig_author: str):
44
- rename_dict = {"RetrievalQA": "Consulting The Kens"}
45
- return rename_dict.get(orig_author, orig_author)
46
-
47
- @cl.on_chat_start
48
- async def init():
49
- msg = cl.Message(content=f"Building Index...")
50
- await msg.send()
51
-
52
- # build FAISS index from csv
53
- loader = CSVLoader(file_path="./data/barbie.csv", source_column="Review_Url")
54
- data = loader.load()
55
- documents = text_splitter.transform_documents(data)
56
- store = LocalFileStore("./cache/")
57
- core_embeddings_model = OpenAIEmbeddings()
58
- embedder = CacheBackedEmbeddings.from_bytes_store(
59
- core_embeddings_model, store, namespace=core_embeddings_model.model
60
  )
61
- # make async docsearch
62
- docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)
63
-
64
- chain = RetrievalQA.from_chain_type(
65
- ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
66
- chain_type="stuff",
67
- return_source_documents=True,
68
- retriever=docsearch.as_retriever(),
69
- chain_type_kwargs = {"prompt": prompt}
70
  )
71
 
72
- msg.content = f"Index built!"
73
- await msg.send()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- cl.user_session.set("chain", chain)
 
 
 
 
 
76
 
 
 
 
 
77
 
78
- @cl.on_message
79
- async def main(message):
80
- chain = cl.user_session.get("chain")
81
- cb = cl.AsyncLangchainCallbackHandler(
82
- stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
 
 
 
 
 
83
  )
84
- cb.answer_reached = True
85
- res = await chain.acall(message, callbacks=[cb], )
86
-
87
- answer = res["result"]
88
- source_elements = []
89
- visited_sources = set()
90
-
91
- # Get the documents from the user session
92
- docs = res["source_documents"]
93
- metadatas = [doc.metadata for doc in docs]
94
- all_sources = [m["source"] for m in metadatas]
95
-
96
- for source in all_sources:
97
- if source in visited_sources:
98
- continue
99
- visited_sources.add(source)
100
- # Create the text element referenced in the message
101
- source_elements.append(
102
- cl.Text(content="https://www.imdb.com" + source, name="Review URL")
103
- )
104
 
105
- if source_elements:
106
- answer += f"\nSources: {', '.join([e.content.decode('utf-8') for e in source_elements])}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  else:
108
- answer += "\nNo sources found"
 
 
 
 
 
109
 
110
- await cl.Message(content=answer, elements=source_elements).send()
 
1
+ from llama_index import ServiceContext, SimpleNodeParser, TokenTextSplitter, OpenAI, OpenAIEmbedding
2
+ from llama_index.vector_stores import ChromaVectorStore
3
+ from llama_index.storage.storage_context import StorageContext
4
+ from llama_index import VectorStoreIndex, WikipediaReader
5
+ from llama_index.tools import FunctionTool
6
+ from llama_index.vector_stores.types import VectorStoreInfo, MetadataInfo, ExactMatchFilter, MetadataFilters
7
+ from llama_index.retrievers import VectorIndexRetriever
8
+ from llama_index.query_engine import RetrieverQueryEngine
9
+ from typing import List, Tuple, Any
10
+ from pydantic import BaseModel, Field
11
+ import chromadb
12
+ import pandas as pd
13
+ from sqlalchemy import create_engine
14
+ from llama_index import SQLDatabase, NLSQLTableQueryEngine, QueryEngineTool
15
+ from llama_index.openai_agent import OpenAIAgent
16
+ from chainlit import ChainLit
17
+
18
+ # Embedding Model and Low-level model
19
+ embed_model = OpenAIEmbedding()
20
+ chunk_size = 1000
21
+ chunk_overlap = 100
22
+ llm = OpenAI(
23
+ temperature=0,
24
+ model="gpt-4-32k",
25
+ streaming=True
26
  )
27
+ service_context = ServiceContext.from_defaults(
28
+ llm=llm,
29
+ chunk_size=chunk_size,
30
+ embed_model=embed_model
31
+ )
32
+ text_splitter = TokenTextSplitter(
33
+ chunk_size=chunk_size,
34
+ chunk_overlap=chunk_overlap
35
+ )
36
+ node_parser = SimpleNodeParser(text_splitter=text_splitter)
37
+
38
+ # Vector Storage and Context
39
+ chroma_client = chromadb.Client()
40
+ chroma_collection = chroma_client.create_collection("wikipedia_barbie_opp")
41
+ vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
42
+ storage_context = StorageContext.from_defaults(vector_store=vector_store)
43
+
44
+ # Your Wikipedia docs retrieval
45
+ movie_list = ["Barbie (film)", "Oppenheimer (film)"]
46
+ wiki_docs = WikipediaReader().load_data(pages=movie_list, auto_suggest=False)
47
+
48
+ # Parsing and storing vectors
49
+ wiki_vector_index = VectorStoreIndex([], storage_context=storage_context, service_context=service_context)
50
+ for movie, wiki_doc in zip(movie_list, wiki_docs):
51
+ nodes = node_parser.get_nodes_from_documents([wiki_doc])
52
+ for node in nodes:
53
+ node.metadata = {"title": movie}
54
+ wiki_vector_index.insert_nodes(nodes)
55
+
56
+ # Defining the tools for vector search and SQL query
57
+ top_k = 3
58
+ vector_store_info = VectorStoreInfo(
59
+ content_info="semantic information about movies",
60
+ metadata_info=[MetadataInfo(
61
+ name="title",
62
+ type="str",
63
+ description="title of the movie, one of [Barbie (film), Oppenheimer (film)]",
64
+ )]
65
+ )
66
+
67
+ # Create PyDantic model for auto retrieval
68
+ class AutoRetrieveModel(BaseModel):
69
+ query: str = Field(..., description="natural language query string")
70
+ filter_key_list: List[str] = Field(
71
+ ..., description="List of metadata filter field names"
72
  )
73
+ filter_value_list: List[str] = Field(
74
+ ...,
75
+ description=(
76
+ "List of metadata filter field values (corresponding to names specified in filter_key_list)"
77
+ )
 
 
 
 
78
  )
79
 
80
+ def auto_retrieve_fn(query: str, filter_key_list: List[str], filter_value_list: List[str]):
81
+ exact_match_filters = [
82
+ ExactMatchFilter(key=k, value=v)
83
+ for k, v in zip(filter_key_list, filter_value_list)
84
+ ]
85
+ retriever = VectorIndexRetriever(
86
+ wiki_vector_index, filters=MetadataFilters(filters=exact_match_filters), top_k=top_k
87
+ )
88
+ query_engine = RetrieverQueryEngine.from_args(retriever)
89
+ response = query_engine.query(query)
90
+ return str(response)
91
+
92
+ description = f"""\
93
+ Use this tool to look up semantic information about films.
94
+ The vector database schema is given below:
95
+ {vector_store_info.json()}
96
+ """
97
+
98
+ auto_retrieve_tool = FunctionTool.from_defaults(
99
+ fn=auto_retrieve_fn,
100
+ name="Auto_Retriever",
101
+ description=description,
102
+ fn_schema=AutoRetrieveModel
103
+ )
104
 
105
+ # SQL setup and tool definition
106
+ barbie_df = pd.read_csv("barbie_data/barbie.csv")
107
+ oppenheimer_df = pd.read_csv("oppenheimer_data/oppenheimer.csv")
108
+ engine = create_engine("sqlite+pysqlite:///:memory:")
109
+ barbie_df.to_sql(name='barbie', con=engine)
110
+ oppenheimer_df.to_sql(name='oppenheimer', con=engine)
111
 
112
+ sql_database = SQLDatabase(
113
+ engine=engine,
114
+ include_tables=['barbie', 'oppenheimer']
115
+ )
116
 
117
+ sql_query_engine = NLSQLTableQueryEngine(
118
+ sql_database=sql_database,
119
+ tables=['barbie', 'oppenheimer']
120
+ )
121
+
122
+ sql_tool = QueryEngineTool.from_defaults(
123
+ query_engine=sql_query_engine,
124
+ name="Natural_Language_to_SQL_Tool",
125
+ description=(
126
+ "Useful for translating a natural language query into a SQL query."
127
  )
128
+ )
129
+
130
+ # Combining both tools into a single OpenAI Agent
131
+ barbenheimer_agent = OpenAIAgent.from_tools(
132
+ tools=[auto_retrieve_tool, sql_tool]
133
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ # Initialize the ChainLit app
136
+ cl = ChainLit()
137
+
138
+ # On-Message Function
139
+ @cl.on_message
140
+ def handle_message(message: str, sender: str) -> Tuple[str, Any]:
141
+ query_result = barbenheimer_agent.query(
142
+ query=message,
143
+ user_id=sender
144
+ )
145
+
146
+ # Extracting relevant information from the query result
147
+ tool_name = query_result.tool_name
148
+ response = query_result.response
149
+
150
+ if tool_name == "Auto_Retriever":
151
+ # Processing for semantic information retrieval
152
+ return "Auto_Retriever", f"Semantic Information:\n{response}"
153
+
154
+ elif tool_name == "Natural_Language_to_SQL_Tool":
155
+ # Processing for SQL-based information
156
+ return "Natural_Language_to_SQL_Tool", f"SQL Query Result:\n{response}"
157
+
158
  else:
159
+ # Handling unrecognized tool queries
160
+ return "Unknown", "I couldn't understand your request."
161
+
162
+ # Running the app
163
+ if __name__ == '__main__':
164
+ cl.run()
165