Spaces:
Runtime error
Runtime error
Migrate to stream generator
Browse files- app/main.py +10 -6
- app/rag.py +6 -4
app/main.py
CHANGED
@@ -22,22 +22,26 @@ app = FastAPI(middleware=middleware)
|
|
22 |
|
23 |
files_dir = os.path.expanduser("~/wtp_be_files/")
|
24 |
session_assistant = ChatPDF()
|
25 |
-
session_messages = []
|
26 |
|
27 |
@app.get("/query")
|
28 |
def process_input(text: str):
|
29 |
if text and len(text.strip()) > 0:
|
30 |
text = text.strip()
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
35 |
|
36 |
|
37 |
@app.post("/upload")
|
38 |
def upload(files: list[UploadFile]):
|
39 |
session_assistant.clear()
|
40 |
-
session_messages = []
|
41 |
|
42 |
try:
|
43 |
os.makedirs(files_dir)
|
|
|
22 |
|
23 |
files_dir = os.path.expanduser("~/wtp_be_files/")
|
24 |
session_assistant = ChatPDF()
|
25 |
+
# session_messages = []
|
26 |
|
27 |
@app.get("/query")
|
28 |
def process_input(text: str):
|
29 |
if text and len(text.strip()) > 0:
|
30 |
text = text.strip()
|
31 |
+
print("PRINTING STREAM")
|
32 |
+
agent_text_stream = session_assistant.ask(text)
|
33 |
+
print(stream_gen)
|
34 |
+
for text in agent_text_stream:
|
35 |
+
print(text)
|
36 |
+
# session_messages.append((text, True))
|
37 |
+
# session_messages.append((agent_text, False))
|
38 |
+
return "Query resolved!"
|
39 |
|
40 |
|
41 |
@app.post("/upload")
|
42 |
def upload(files: list[UploadFile]):
|
43 |
session_assistant.clear()
|
44 |
+
# session_messages = []
|
45 |
|
46 |
try:
|
47 |
os.makedirs(files_dir)
|
app/rag.py
CHANGED
@@ -151,7 +151,10 @@ class ChatPDF:
|
|
151 |
# response_synthesizer=response_synthesizer,
|
152 |
# )
|
153 |
|
154 |
-
self.query_engine = index.as_query_engine(
|
|
|
|
|
|
|
155 |
|
156 |
def ask(self, query: str):
|
157 |
if not self.query_engine:
|
@@ -159,9 +162,8 @@ class ChatPDF:
|
|
159 |
|
160 |
logger.info("retrieving the response to the query")
|
161 |
# response = self.query_engine.query(str_or_query_bundle=query)
|
162 |
-
|
163 |
-
|
164 |
-
return response
|
165 |
|
166 |
def clear(self):
|
167 |
self.query_engine = None
|
|
|
151 |
# response_synthesizer=response_synthesizer,
|
152 |
# )
|
153 |
|
154 |
+
self.query_engine = index.as_query_engine(
|
155 |
+
streaming=True,
|
156 |
+
# similarity_top_k=6,
|
157 |
+
)
|
158 |
|
159 |
def ask(self, query: str):
|
160 |
if not self.query_engine:
|
|
|
162 |
|
163 |
logger.info("retrieving the response to the query")
|
164 |
# response = self.query_engine.query(str_or_query_bundle=query)
|
165 |
+
streaming_response = self.query_engine.query(query)
|
166 |
+
return streaming_response.response_gen
|
|
|
167 |
|
168 |
def clear(self):
|
169 |
self.query_engine = None
|