jfeng1115 commited on
Commit
17c96aa
1 Parent(s): f867823

add stream pipeline

Browse files
Files changed (2) hide show
  1. aimakerspace/openai_utils/chatmodel.py +7 -0
  2. app.py +17 -1
aimakerspace/openai_utils/chatmodel.py CHANGED
@@ -25,3 +25,10 @@ class ChatOpenAI:
25
  return response.choices[0].message.content
26
 
27
  return response
 
 
 
 
 
 
 
 
25
  return response.choices[0].message.content
26
 
27
  return response
28
+
29
+ def run_stream(self, messages, settings, chainlit_msg, text_only: bool = True):
30
+ async for stream_resp in await openai.Completion.acreate(
31
+ model=self.model_name, prompt=messages, stream=True, **settings
32
+ ):
33
+ token = stream_resp.get("choices")[0].get("text")
34
+ await chainlit_msg.stream_token(token)
app.py CHANGED
@@ -76,6 +76,19 @@ class RetrievalAugmentedQAPipeline:
76
 
77
  return self.llm.run([formatted_system_prompt, formatted_user_prompt])
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  @cl.on_chat_start # marks a function that will be executed at the start of a user session
81
  def start_chat():
@@ -97,7 +110,10 @@ def start_chat():
97
  async def main(message: str):
98
 
99
  qaPipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db, llm=chat_openai)
100
- qaPipeline.run_pipeline(user_query=message)
 
 
 
101
 
102
 
103
 
 
76
 
77
  return self.llm.run([formatted_system_prompt, formatted_user_prompt])
78
 
79
+ def stream_pipeline(self, user_query: str, msg: cl.Message) -> str:
80
+ context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
81
+
82
+ context_prompt = ""
83
+ for context in context_list:
84
+ context_prompt += context[0] + "\n"
85
+
86
+ formatted_system_prompt = raqa_prompt.create_message(context=context_prompt)
87
+
88
+ formatted_user_prompt = user_prompt.create_message(user_query=user_query)
89
+
90
+ self.llm.stream([formatted_system_prompt, formatted_user_prompt])
91
+
92
 
93
  @cl.on_chat_start # marks a function that will be executed at the start of a user session
94
  def start_chat():
 
110
  async def main(message: str):
111
 
112
  qaPipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db, llm=chat_openai)
113
+ msg = cl.Message(content="")
114
+
115
+ qaPipeline.stream_pipeline(user_query=message, msg=msg)
116
+ await msg.send()
117
 
118
 
119