add stream pipeline
- aimakerspace/openai_utils/chatmodel.py +7 -0
- app.py +17 -1
aimakerspace/openai_utils/chatmodel.py
CHANGED
@@ -25,3 +25,10 @@ class ChatOpenAI:
         return response.choices[0].message.content
 
         return response
+
+    async def run_stream(self, messages, settings, chainlit_msg, text_only: bool = True):
+        async for stream_resp in await openai.ChatCompletion.acreate(
+            model=self.model_name, messages=messages, stream=True, **settings
+        ):
+            token = stream_resp["choices"][0]["delta"].get("content", "")
+            await chainlit_msg.stream_token(token)
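For reference, the new coroutine can be exercised outside Chainlit by standing in for cl.Message; the sketch below is a minimal, hypothetical harness, assuming the legacy openai<1.0 SDK that acreate belongs to and a default ChatOpenAI constructor (StubMessage, demo, and the prompt text are illustrative, not part of this commit):

    import asyncio

    from aimakerspace.openai_utils.chatmodel import ChatOpenAI

    class StubMessage:
        # Stand-in for cl.Message when testing outside a Chainlit session.
        async def stream_token(self, token: str) -> None:
            print(token, end="", flush=True)  # echo each token as it arrives

    async def demo() -> None:
        llm = ChatOpenAI()  # assumed default constructor
        await llm.run_stream(
            messages=[{"role": "user", "content": "Say hello."}],
            settings={"temperature": 0.0},  # forwarded as **settings to acreate
            chainlit_msg=StubMessage(),
        )

    asyncio.run(demo())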
app.py
CHANGED
@@ -76,6 +76,19 @@ class RetrievalAugmentedQAPipeline:
 
         return self.llm.run([formatted_system_prompt, formatted_user_prompt])
 
+    async def stream_pipeline(self, user_query: str, msg: cl.Message) -> None:
+        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
+
+        context_prompt = ""
+        for context in context_list:
+            context_prompt += context[0] + "\n"
+
+        formatted_system_prompt = raqa_prompt.create_message(context=context_prompt)
+
+        formatted_user_prompt = user_prompt.create_message(user_query=user_query)
+
+        await self.llm.run_stream([formatted_system_prompt, formatted_user_prompt], settings={}, chainlit_msg=msg)
+
 
 @cl.on_chat_start  # marks a function that will be executed at the start of a user session
 def start_chat():
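The loop above assumes each hit returned by search_by_text is a (text, score) pair and keeps only the text. The concatenation it performs is equivalent to this short sketch (the sample hits are made up):

    # Equivalent context assembly, assuming (text, score) tuples from the retriever.
    hits = [("First relevant chunk.", 0.91), ("Second relevant chunk.", 0.87)]
    context_prompt = "".join(text + "\n" for text, _score in hits)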
@@ -97,7 +110,10 @@ def start_chat():
 async def main(message: str):
 
     qaPipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db, llm=chat_openai)
-
+    msg = cl.Message(content="")
+
+    await qaPipeline.stream_pipeline(user_query=message, msg=msg)
+    await msg.send()
 
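Taken together, the two hunks implement Chainlit's usual token-streaming pattern: create an empty message up front, stream tokens into it, then finalize it. A minimal sketch of that pattern in isolation, with a hard-coded token list standing in for the OpenAI stream:

    import chainlit as cl

    @cl.on_message
    async def main(message: str):
        msg = cl.Message(content="")            # empty bubble shown immediately
        for token in ["Hello", ", ", "world"]:  # stand-in for streamed model tokens
            await msg.stream_token(token)       # append each token in real time
        await msg.send()                        # mark the streamed message complete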