raul-padua commited on
Commit
b72c08a
1 Parent(s): 53ed07c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -155
app.py CHANGED
@@ -1,165 +1,81 @@
1
- from llama_index import ServiceContext, SimpleNodeParser, TokenTextSplitter, OpenAI, OpenAIEmbedding
2
- from llama_index.vector_stores import ChromaVectorStore
3
- from llama_index.storage.storage_context import StorageContext
4
- from llama_index import VectorStoreIndex, WikipediaReader
5
- from llama_index.tools import FunctionTool
6
- from llama_index.vector_stores.types import VectorStoreInfo, MetadataInfo, ExactMatchFilter, MetadataFilters
7
- from llama_index.retrievers import VectorIndexRetriever
8
- from llama_index.query_engine import RetrieverQueryEngine
9
- from typing import List, Tuple, Any
10
- from pydantic import BaseModel, Field
11
- import chromadb
12
- import pandas as pd
13
- from sqlalchemy import create_engine
14
- from llama_index import SQLDatabase, NLSQLTableQueryEngine, QueryEngineTool
15
- from llama_index.openai_agent import OpenAIAgent
16
- from chainlit import ChainLit
17
-
18
- # Embedding Model and Low-level model
19
- embed_model = OpenAIEmbedding()
20
- chunk_size = 1000
21
- chunk_overlap = 100
22
- llm = OpenAI(
23
- temperature=0,
24
- model="gpt-4-32k",
25
- streaming=True
26
- )
27
- service_context = ServiceContext.from_defaults(
28
- llm=llm,
29
- chunk_size=chunk_size,
30
- embed_model=embed_model
31
  )
32
- text_splitter = TokenTextSplitter(
33
- chunk_size=chunk_size,
34
- chunk_overlap=chunk_overlap
35
- )
36
- node_parser = SimpleNodeParser(text_splitter=text_splitter)
37
-
38
- # Vector Storage and Context
39
- chroma_client = chromadb.Client()
40
- chroma_collection = chroma_client.create_collection("wikipedia_barbie_opp")
41
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
42
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
43
-
44
- # Your Wikipedia docs retrieval
45
- movie_list = ["Barbie (film)", "Oppenheimer (film)"]
46
- wiki_docs = WikipediaReader().load_data(pages=movie_list, auto_suggest=False)
47
-
48
- # Parsing and storing vectors
49
- wiki_vector_index = VectorStoreIndex([], storage_context=storage_context, service_context=service_context)
50
- for movie, wiki_doc in zip(movie_list, wiki_docs):
51
- nodes = node_parser.get_nodes_from_documents([wiki_doc])
52
- for node in nodes:
53
- node.metadata = {"title": movie}
54
- wiki_vector_index.insert_nodes(nodes)
55
-
56
- # Defining the tools for vector search and SQL query
57
- top_k = 3
58
- vector_store_info = VectorStoreInfo(
59
- content_info="semantic information about movies",
60
- metadata_info=[MetadataInfo(
61
- name="title",
62
- type="str",
63
- description="title of the movie, one of [Barbie (film), Oppenheimer (film)]",
64
- )]
65
- )
66
-
67
- # Create PyDantic model for auto retrieval
68
- class AutoRetrieveModel(BaseModel):
69
- query: str = Field(..., description="natural language query string")
70
- filter_key_list: List[str] = Field(
71
- ..., description="List of metadata filter field names"
72
- )
73
- filter_value_list: List[str] = Field(
74
- ...,
75
- description=(
76
- "List of metadata filter field values (corresponding to names specified in filter_key_list)"
77
- )
78
  )
79
-
80
- def auto_retrieve_fn(query: str, filter_key_list: List[str], filter_value_list: List[str]):
81
- exact_match_filters = [
82
- ExactMatchFilter(key=k, value=v)
83
- for k, v in zip(filter_key_list, filter_value_list)
84
- ]
85
- retriever = VectorIndexRetriever(
86
- wiki_vector_index, filters=MetadataFilters(filters=exact_match_filters), top_k=top_k
87
  )
88
- query_engine = RetrieverQueryEngine.from_args(retriever)
89
- response = query_engine.query(query)
90
- return str(response)
91
-
92
- description = f"""\
93
- Use this tool to look up semantic information about films.
94
- The vector database schema is given below:
95
- {vector_store_info.json()}
96
- """
97
-
98
- auto_retrieve_tool = FunctionTool.from_defaults(
99
- fn=auto_retrieve_fn,
100
- name="Auto_Retriever",
101
- description=description,
102
- fn_schema=AutoRetrieveModel
103
- )
104
-
105
- # SQL setup and tool definition
106
- barbie_df = pd.read_csv("barbie_data/barbie.csv")
107
- oppenheimer_df = pd.read_csv("oppenheimer_data/oppenheimer.csv")
108
- engine = create_engine("sqlite+pysqlite:///:memory:")
109
- barbie_df.to_sql(name='barbie', con=engine)
110
- oppenheimer_df.to_sql(name='oppenheimer', con=engine)
111
-
112
- sql_database = SQLDatabase(
113
- engine=engine,
114
- include_tables=['barbie', 'oppenheimer']
115
- )
116
 
117
- sql_query_engine = NLSQLTableQueryEngine(
118
- sql_database=sql_database,
119
- tables=['barbie', 'oppenheimer']
120
- )
121
-
122
- sql_tool = QueryEngineTool.from_defaults(
123
- query_engine=sql_query_engine,
124
- name="Natural_Language_to_SQL_Tool",
125
- description=(
126
- "Useful for translating a natural language query into a SQL query."
127
  )
128
- )
 
129
 
130
- # Combining both tools into a single OpenAI Agent
131
- barbenheimer_agent = OpenAIAgent.from_tools(
132
- tools=[auto_retrieve_tool, sql_tool]
133
- )
134
-
135
- # Initialize the ChainLit app
136
- cl = ChainLit()
137
-
138
- # On-Message Function
139
  @cl.on_message
140
- def handle_message(message: str, sender: str) -> Tuple[str, Any]:
141
- query_result = barbenheimer_agent.query(
142
- query=message,
143
- user_id=sender
144
- )
145
-
146
- # Extracting relevant information from the query result
147
- tool_name = query_result.tool_name
148
- response = query_result.response
149
-
150
- if tool_name == "Auto_Retriever":
151
- # Processing for semantic information retrieval
152
- return "Auto_Retriever", f"Semantic Information:\n{response}"
153
 
154
- elif tool_name == "Natural_Language_to_SQL_Tool":
155
- # Processing for SQL-based information
156
- return "Natural_Language_to_SQL_Tool", f"SQL Query Result:\n{response}"
157
-
158
- else:
159
- # Handling unrecognized tool queries
160
- return "Unknown", "I couldn't understand your request."
 
161
 
162
- # Running the app
163
- if __name__ == '__main__':
164
- cl.run()
165
 
 
1
+ import os
2
+ import openai
3
+ import logging
4
+ from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
5
+ from llama_index.callbacks.base import CallbackManager
6
+ from llama_index import (
7
+ LLMPredictor,
8
+ ServiceContext,
9
+ StorageContext,
10
+ load_index_from_storage,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  )
12
+ from langchain.chat_models import ChatOpenAI
13
+ import chainlit as cl
14
+
15
+ # Set up logging for debugging and monitoring
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Load OpenAI API key
20
+ openai.api_key = os.environ.get("OPENAI_API_KEY")
21
+
22
+ try:
23
+ # Attempt to rebuild storage context and load index
24
+ logger.info("Attempting to load index from storage.")
25
+ storage_context = StorageContext.from_defaults(persist_dir="./storage")
26
+ index = load_index_from_storage(storage_context)
27
+ except Exception as e:
28
+ # If index loading fails, create a new index
29
+ logger.warning(f"Failed to load index from storage: {e}. Creating a new index.")
30
+ from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader
31
+
32
+ documents = SimpleDirectoryReader("./data").load_data()
33
+ index = GPTVectorStoreIndex.from_documents(documents)
34
+ index.storage_context.persist()
35
+ logger.info("New index created and persisted.")
36
+
37
+ @cl.on_chat_start
38
+ async def factory():
39
+ embed_model = OpenAIEmbedding()
40
+ chunk_size = 1000
41
+
42
+ llm_predictor = LLMPredictor(
43
+ llm=ChatOpenAI(
44
+ temperature=0,
45
+ model_name="gpt-4-32k",
46
+ streaming=True,
47
+ ),
 
 
 
 
 
 
 
 
 
 
48
  )
49
+ service_context = ServiceContext.from_defaults(
50
+ llm_predictor=llm_predictor,
51
+ chunk_size=chunk_size,
52
+ callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
 
 
 
 
53
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ query_engine = index.as_query_engine(
56
+ service_context=service_context,
57
+ streaming=True,
 
 
 
 
 
 
 
58
  )
59
+ logger.info("Query engine initialized.") # to facilitate debugging and monitoring
60
+ cl.user_session.set("query_engine", query_engine)
61
 
 
 
 
 
 
 
 
 
 
62
  @cl.on_message
63
+ async def main(message):
64
+ try:
65
+ query_engine = cl.user_session.get("query_engine") # type: RetrieverQueryEngine
66
+ logger.info(f"Received message: {message}")
67
+ response = await cl.make_async(query_engine.query)(message)
68
+ response_message = cl.Message(content="")
 
 
 
 
 
 
 
69
 
70
+ for token in response.response_gen:
71
+ await response_message.stream_token(token=token)
72
+
73
+ if response.response_txt:
74
+ response_message.content = response.response_txt
75
+
76
+ await response_message.send()
77
+ logger.info(f"Response sent: {response.response_txt}")
78
 
79
+ except Exception as e:
80
+ logger.error(f"An error occurred while processing the message: {e}")
 
81