ray commited on
Commit
693929a
1 Parent(s): dfc6dc5

factor out index builder from chatbot

Browse files
Files changed (2) hide show
  1. app.py +61 -37
  2. chatbot.py +49 -30
app.py CHANGED
@@ -1,35 +1,44 @@
1
- import re
2
- from typing import List
3
  import gradio as gr
4
  import openai
5
  import os
6
  from dotenv import load_dotenv
7
  import phoenix as px
8
  import llama_index
9
- from llama_index import OpenAIEmbedding, Prompt, ServiceContext, VectorStoreIndex, SimpleDirectoryReader
10
  from llama_index.chat_engine.types import ChatMode
11
- from llama_index.llms import ChatMessage, MessageRole, OpenAI
12
  from llama_index.vector_stores.qdrant import QdrantVectorStore
13
  from llama_index.text_splitter import SentenceSplitter
14
  from llama_index.extractors import TitleExtractor
15
  from llama_index.ingestion import IngestionPipeline
16
  from chat_template import CHAT_TEXT_QA_PROMPT
17
- from chatbot import Chatbot, ChatbotVersion
 
18
  from custom_io import UnstructuredReader, default_file_metadata_func
19
  from qdrant import client as qdrantClient
 
20
 
21
- load_dotenv()
22
- openai.api_key = os.getenv("OPENAI_API_KEY")
23
 
24
 
25
- class AwesumCareChatbot(Chatbot):
26
- DENIED_ANSWER_PROMPT = ""
27
- SYSTEM_PROMPT = ""
28
- CHAT_EXAMPLES = [
29
- "什麼是安心三寶?",
30
- "點樣立平安紙?"
31
- ]
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def _load_doucments(self):
34
  dir_reader = SimpleDirectoryReader('./awesumcare_data', file_extractor={
35
  ".pdf": UnstructuredReader(),
@@ -41,14 +50,9 @@ class AwesumCareChatbot(Chatbot):
41
  file_metadata=default_file_metadata_func)
42
 
43
  self.documents = dir_reader.load_data()
44
- super()._load_doucments()
45
 
46
  def _setup_service_context(self):
47
- self.service_context = ServiceContext.from_defaults(
48
- chunk_size=self.chunk_size,
49
- llm=self.llm,
50
- embed_model=self.embed_model
51
- )
52
  super()._setup_service_context()
53
 
54
  def _setup_vector_store(self):
@@ -57,21 +61,34 @@ class AwesumCareChatbot(Chatbot):
57
  super()._setup_vector_store()
58
 
59
  def _setup_index(self):
60
- if self.vdb_collection_name in [col.name for col in qdrantClient.get_collections().collections] and qdrantClient.get_collection(self.vdb_collection_name).vectors_count > 0:
61
- self.index = VectorStoreIndex.from_vector_store(
62
- self.vector_store, service_context=self.service_context)
63
  print("set up index from vector store")
64
  return
65
  pipeline = IngestionPipeline(
66
  transformations=[
67
  SentenceSplitter(),
68
- OpenAIEmbedding(),
69
  ],
70
  vector_store=self.vector_store,
71
  )
72
  pipeline.run(documents=self.documents)
73
- self.index = VectorStoreIndex.from_vector_store(
74
- self.vector_store, service_context=self.service_context)
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  super()._setup_index()
76
 
77
  # def _setup_index(self):
@@ -81,18 +98,23 @@ class AwesumCareChatbot(Chatbot):
81
  # )
82
  # super()._setup_index()
83
 
 
 
 
 
 
 
 
 
 
 
 
84
  def _setup_chat_engine(self):
85
  # testing #
86
  from llama_index.agent import OpenAIAgent
87
- from llama_index.tools.query_engine import QueryEngineTool
88
-
89
- query_engine = self.index.as_query_engine(
90
- text_qa_template=CHAT_TEXT_QA_PROMPT)
91
- query_engine_tool = QueryEngineTool.from_defaults(
92
- query_engine=query_engine)
93
  self.chat_engine = OpenAIAgent.from_tools(
94
- tools=[query_engine_tool],
95
- llm=self.service_context.llm,
96
  similarity_top_k=1,
97
  verbose=True
98
  )
@@ -106,9 +128,11 @@ class AwesumCareChatbot(Chatbot):
106
 
107
 
108
  # gpt-3.5-turbo-1106, gpt-4-1106-preview
109
- awesum_chatbot = AwesumCareChatbot(ChatbotVersion.CHATGPT_35.value,
110
- chunk_size=2048,
111
- vdb_collection_name="v2")
 
 
112
 
113
 
114
  def vote(data: gr.LikeData):
 
 
 
1
  import gradio as gr
2
  import openai
3
  import os
4
  from dotenv import load_dotenv
5
  import phoenix as px
6
  import llama_index
7
+ from llama_index import Prompt, ServiceContext, VectorStoreIndex, SimpleDirectoryReader
8
  from llama_index.chat_engine.types import ChatMode
9
+ from llama_index.llms import ChatMessage, MessageRole
10
  from llama_index.vector_stores.qdrant import QdrantVectorStore
11
  from llama_index.text_splitter import SentenceSplitter
12
  from llama_index.extractors import TitleExtractor
13
  from llama_index.ingestion import IngestionPipeline
14
  from chat_template import CHAT_TEXT_QA_PROMPT
15
+ from schemas import ChatbotVersion, ServiceProvider
16
+ from chatbot import Chatbot, IndexBuilder
17
  from custom_io import UnstructuredReader, default_file_metadata_func
18
  from qdrant import client as qdrantClient
19
+ from llama_index import set_global_service_context
20
 
21
+ from service_provider_config import get_service_provider_config
 
22
 
23
 
24
+ # initial service setup
25
+ px.launch_app()
26
+ llama_index.set_global_handler("arize_phoenix")
 
 
 
 
27
 
28
+ load_dotenv()
29
+ openai.api_key = os.getenv("OPENAI_API_KEY")
30
+ CHUNK_SIZE = 1024
31
+ LLM, EMBED_MODEL = get_service_provider_config(
32
+ service_provider=ServiceProvider.OPENAI)
33
+ service_context = ServiceContext.from_defaults(
34
+ chunk_size=CHUNK_SIZE,
35
+ llm=LLM,
36
+ embed_model=EMBED_MODEL,
37
+ )
38
+ set_global_service_context(service_context)
39
+
40
+
41
+ class AwesumIndexBuilder(IndexBuilder):
42
  def _load_doucments(self):
43
  dir_reader = SimpleDirectoryReader('./awesumcare_data', file_extractor={
44
  ".pdf": UnstructuredReader(),
 
50
  file_metadata=default_file_metadata_func)
51
 
52
  self.documents = dir_reader.load_data()
53
+ print(f"Loaded {len(self.documents)} docs")
54
 
55
  def _setup_service_context(self):
 
 
 
 
 
56
  super()._setup_service_context()
57
 
58
  def _setup_vector_store(self):
 
61
  super()._setup_vector_store()
62
 
63
  def _setup_index(self):
64
+ super()._setup_index()
65
+ if self.is_load_from_vector_store:
66
+ self.index = VectorStoreIndex.from_vector_store(self.vector_store)
67
  print("set up index from vector store")
68
  return
69
  pipeline = IngestionPipeline(
70
  transformations=[
71
  SentenceSplitter(),
72
+ EMBED_MODEL,
73
  ],
74
  vector_store=self.vector_store,
75
  )
76
  pipeline.run(documents=self.documents)
77
+ self.index = VectorStoreIndex.from_vector_store(self.vector_store)
78
+
79
+
80
+ class AwesumCareChatbot(Chatbot):
81
+ DENIED_ANSWER_PROMPT = ""
82
+ SYSTEM_PROMPT = ""
83
+ CHAT_EXAMPLES = [
84
+ "什麼是安心三寶?",
85
+ "點樣立平安紙?"
86
+ ]
87
+
88
+ def _setup_observer(self):
89
+ pass
90
+
91
+ def _setup_index(self):
92
  super()._setup_index()
93
 
94
  # def _setup_index(self):
 
98
  # )
99
  # super()._setup_index()
100
 
101
+ def _setup_query_engine(self):
102
+ super()._setup_query_engine()
103
+ self.query_engine = self.index.as_query_engine(
104
+ text_qa_template=CHAT_TEXT_QA_PROMPT)
105
+
106
+ def _setup_tools(self):
107
+ from llama_index.tools.query_engine import QueryEngineTool
108
+ self.tools = QueryEngineTool.from_defaults(
109
+ query_engine=self.query_engine)
110
+ return super()._setup_tools()
111
+
112
  def _setup_chat_engine(self):
113
  # testing #
114
  from llama_index.agent import OpenAIAgent
 
 
 
 
 
 
115
  self.chat_engine = OpenAIAgent.from_tools(
116
+ tools=[self.tools],
117
+ llm=LLM,
118
  similarity_top_k=1,
119
  verbose=True
120
  )
 
128
 
129
 
130
  # gpt-3.5-turbo-1106, gpt-4-1106-preview
131
+ awesum_chatbot = AwesumCareChatbot(model_name=ChatbotVersion.CHATGPT_35.value,
132
+ index_builder=AwesumIndexBuilder(
133
+ vdb_collection_name="demo-v0",
134
+ is_load_from_vector_store=True)
135
+ )
136
 
137
 
138
  def vote(data: gr.LikeData):
chatbot.py CHANGED
@@ -13,55 +13,82 @@ from llama_index.llms import ChatMessage, MessageRole, OpenAI
13
  load_dotenv()
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class Chatbot:
17
  SYSTEM_PROMPT = ""
18
  DENIED_ANSWER_PROMPT = ""
19
  CHAT_EXAMPLES = []
20
 
21
- def __init__(self, model_name, chunk_size, vdb_collection_name="test_store"):
22
  self.model_name = model_name
23
- self.llm = OpenAI(model=self.model_name)
24
- self.embed_model = OpenAIEmbedding()
25
- self.chunk_size = chunk_size
26
 
27
  self.documents = None
28
  self.index = None
29
  self.chat_engine = None
30
  self.service_context = None
31
  self.vector_store = None
32
- self.vdb_collection_name = vdb_collection_name
33
 
34
  self._setup_chatbot()
35
 
36
  def _setup_chatbot(self):
37
- self._setup_observer()
38
- self._setup_service_context()
39
- self._setup_vector_store()
40
- self._load_doucments()
41
  self._setup_index()
 
 
42
  self._setup_chat_engine()
43
 
44
  def _setup_observer(self):
45
  px.launch_app()
46
  llama_index.set_global_handler("arize_phoenix")
47
 
48
- def _load_doucments(self):
49
- pass
50
- print(f"Loaded {len(self.documents)} docs")
51
-
52
- def _setup_service_context(self):
53
- pass
54
- print("Setup service context...")
55
 
56
- def _setup_vector_store(self):
 
 
57
  pass
58
- print("Setup vector store...")
59
 
60
- def _setup_index(self):
61
- if self.documents is None:
62
- raise ValueError("No documents loaded")
63
  pass
64
- print("Built index...")
65
 
66
  def _setup_chat_engine(self):
67
  if self.index is None:
@@ -141,11 +168,3 @@ class Chatbot:
141
  # For 'Vanilla ChatGPT' - No system prompt
142
  def predict_vanilla_chatgpt(self, message, history):
143
  yield from self._invoke_chatgpt(history, message)
144
-
145
-
146
- # make a enum of chatbot type and string
147
-
148
-
149
- class ChatbotVersion(str, Enum):
150
- CHATGPT_35 = "gpt-3.5-turbo-1106"
151
- CHATGPT_4 = "gpt-4-1106-preview"
 
13
  load_dotenv()
14
 
15
 
16
+ class IndexBuilder:
17
+ def __init__(self, vdb_collection_name, is_load_from_vector_store=False):
18
+ self.documents = None
19
+ self.vdb_collection_name = vdb_collection_name
20
+ self.index = None
21
+ self.is_load_from_vector_store = is_load_from_vector_store
22
+ self.build_index()
23
+
24
+ def _load_doucments(self):
25
+ pass
26
+
27
+ def _setup_service_context(self):
28
+ print("Using global service context...")
29
+
30
+ def _setup_vector_store(self):
31
+ print("Setup vector store...")
32
+
33
+ def _setup_index(self):
34
+ if not self.is_load_from_vector_store and self.documents is None:
35
+ raise ValueError("No documents provided for index building.")
36
+ print("Building Index")
37
+
38
+ def build_index(self):
39
+ if self.is_load_from_vector_store:
40
+ self._setup_service_context()
41
+ self._setup_vector_store()
42
+ self._setup_index()
43
+ return
44
+ self._load_doucments()
45
+ self._setup_service_context()
46
+ self._setup_vector_store()
47
+ self._setup_index()
48
+
49
+
50
  class Chatbot:
51
  SYSTEM_PROMPT = ""
52
  DENIED_ANSWER_PROMPT = ""
53
  CHAT_EXAMPLES = []
54
 
55
+ def __init__(self, model_name, index_builder: IndexBuilder):
56
  self.model_name = model_name
57
+ self.index_builder = index_builder
 
 
58
 
59
  self.documents = None
60
  self.index = None
61
  self.chat_engine = None
62
  self.service_context = None
63
  self.vector_store = None
64
+ self.tools = None
65
 
66
  self._setup_chatbot()
67
 
68
  def _setup_chatbot(self):
69
+ # self._setup_observer()
 
 
 
70
  self._setup_index()
71
+ self._setup_query_engine()
72
+ self._setup_tools()
73
  self._setup_chat_engine()
74
 
75
  def _setup_observer(self):
76
  px.launch_app()
77
  llama_index.set_global_handler("arize_phoenix")
78
 
79
+ def _setup_index(self):
80
+ self.index = self.index_builder.index
81
+ print("Inherited index builder")
 
 
 
 
82
 
83
+ def _setup_query_engine(self):
84
+ if self.index is None:
85
+ raise ValueError("No index built")
86
  pass
87
+ print("Setup query engine...")
88
 
89
+ def _setup_tools(self):
 
 
90
  pass
91
+ print("Setup tools...")
92
 
93
  def _setup_chat_engine(self):
94
  if self.index is None:
 
168
  # For 'Vanilla ChatGPT' - No system prompt
169
  def predict_vanilla_chatgpt(self, message, history):
170
  yield from self._invoke_chatgpt(history, message)