Spaces:
Sleeping
Sleeping
suwonpabby
commited on
Commit
·
c861cde
1
Parent(s):
49b0675
Modify
Browse files
app.py
CHANGED
@@ -59,11 +59,11 @@ embedding_model = BGEM3FlagModel(embedding_model_name, use_fp16=True)
|
|
59 |
# Util Functions
|
60 |
|
61 |
# 1개 데이터 처리, 배치 단위 아님
|
62 |
-
def qa_2_str(
|
63 |
result = ""
|
64 |
|
65 |
-
if len(
|
66 |
-
for idx, message in enumerate(
|
67 |
if idx % 2 == 0: # Q
|
68 |
result += f"User: {message}\n"
|
69 |
else: # A
|
@@ -174,11 +174,11 @@ Assistant: {context_example}
|
|
174 |
|
175 |
|
176 |
@spaces.GPU(duration=35)
|
177 |
-
def make_gen(
|
178 |
start_time = time.time()
|
179 |
|
180 |
# Make For Rag Prompt
|
181 |
-
rag_prompt = qa_2_str(
|
182 |
|
183 |
# Do RAG
|
184 |
query_embeddings = embedding_model.encode([rag_prompt],
|
@@ -199,7 +199,7 @@ def make_gen(qa, candidates, top_k, character_type):
|
|
199 |
|
200 |
# Make For LLM Prompt
|
201 |
|
202 |
-
final_prompt = make_prompt(
|
203 |
|
204 |
# Use LLM
|
205 |
streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=True)
|
@@ -239,12 +239,12 @@ def make_gen(qa, candidates, top_k, character_type):
|
|
239 |
|
240 |
|
241 |
@app.get("/")
|
242 |
-
async def root_endpoint(
|
243 |
-
return StreamingResponse(gen_stream(
|
244 |
|
245 |
|
246 |
-
async def gen_stream(
|
247 |
-
for value in make_gen(
|
248 |
yield value
|
249 |
|
250 |
|
|
|
59 |
# Util Functions
|
60 |
|
61 |
# 1개 데이터 처리, 배치 단위 아님
|
62 |
+
def qa_2_str(QA: List) -> str:
|
63 |
result = ""
|
64 |
|
65 |
+
if len(QA) > 1:
|
66 |
+
for idx, message in enumerate(QA[:-1]):
|
67 |
if idx % 2 == 0: # Q
|
68 |
result += f"User: {message}\n"
|
69 |
else: # A
|
|
|
174 |
|
175 |
|
176 |
@spaces.GPU(duration=35)
|
177 |
+
def make_gen(QA, candidates, top_k, character_type):
|
178 |
start_time = time.time()
|
179 |
|
180 |
# Make For Rag Prompt
|
181 |
+
rag_prompt = qa_2_str(QA)
|
182 |
|
183 |
# Do RAG
|
184 |
query_embeddings = embedding_model.encode([rag_prompt],
|
|
|
199 |
|
200 |
# Make For LLM Prompt
|
201 |
|
202 |
+
final_prompt = make_prompt(QA, rag_result, character_type)
|
203 |
|
204 |
# Use LLM
|
205 |
streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=True)
|
|
|
239 |
|
240 |
|
241 |
@app.get("/")
|
242 |
+
async def root_endpoint(QA: List[str] = Query(...), candidates: List[str] = Query(...), top_k: int = Query(...), character_type: int = Query(...)):
|
243 |
+
return StreamingResponse(gen_stream(QA, candidates, top_k, character_type), media_type="text/event-stream")
|
244 |
|
245 |
|
246 |
+
async def gen_stream(QA, candidates, top_k, character_type):
|
247 |
+
for value in make_gen(QA, candidates, top_k, character_type):
|
248 |
yield value
|
249 |
|
250 |
|