suwonpabby committed on
Commit
c861cde
·
1 Parent(s): 49b0675
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -59,11 +59,11 @@ embedding_model = BGEM3FlagModel(embedding_model_name, use_fp16=True)
59
  # Util Functions
60
 
61
  # 1개 데이터 처리, 배치 단위 아님
62
- def qa_2_str(qa: List) -> str:
63
  result = ""
64
 
65
- if len(qa) > 1:
66
- for idx, message in enumerate(qa[:-1]):
67
  if idx % 2 == 0: # Q
68
  result += f"User: {message}\n"
69
  else: # A
@@ -174,11 +174,11 @@ Assistant: {context_example}
174
 
175
 
176
  @spaces.GPU(duration=35)
177
- def make_gen(qa, candidates, top_k, character_type):
178
  start_time = time.time()
179
 
180
  # Make For Rag Prompt
181
- rag_prompt = qa_2_str(qa)
182
 
183
  # Do RAG
184
  query_embeddings = embedding_model.encode([rag_prompt],
@@ -199,7 +199,7 @@ def make_gen(qa, candidates, top_k, character_type):
199
 
200
  # Make For LLM Prompt
201
 
202
- final_prompt = make_prompt(qa, rag_result, character_type)
203
 
204
  # Use LLM
205
  streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=True)
@@ -239,12 +239,12 @@ def make_gen(qa, candidates, top_k, character_type):
239
 
240
 
241
  @app.get("/")
242
- async def root_endpoint(qa: List[str], candidates: List[str] = Query(...), top_k: int = Query(...), character_type: int = Query(...)):
243
- return StreamingResponse(gen_stream(qa, candidates, top_k, character_type), media_type="text/event-stream")
244
 
245
 
246
- async def gen_stream(qa, candidates, top_k, character_type):
247
- for value in make_gen(qa, candidates, top_k, character_type):
248
  yield value
249
 
250
 
 
59
  # Util Functions
60
 
61
  # 1개 데이터 처리, 배치 단위 아님
62
+ def qa_2_str(QA: List) -> str:
63
  result = ""
64
 
65
+ if len(QA) > 1:
66
+ for idx, message in enumerate(QA[:-1]):
67
  if idx % 2 == 0: # Q
68
  result += f"User: {message}\n"
69
  else: # A
 
174
 
175
 
176
  @spaces.GPU(duration=35)
177
+ def make_gen(QA, candidates, top_k, character_type):
178
  start_time = time.time()
179
 
180
  # Make For Rag Prompt
181
+ rag_prompt = qa_2_str(QA)
182
 
183
  # Do RAG
184
  query_embeddings = embedding_model.encode([rag_prompt],
 
199
 
200
  # Make For LLM Prompt
201
 
202
+ final_prompt = make_prompt(QA, rag_result, character_type)
203
 
204
  # Use LLM
205
  streamer = TextIteratorStreamer(llm_tokenizer, skip_special_tokens=True)
 
239
 
240
 
241
  @app.get("/")
242
+ async def root_endpoint(QA: List[str] = Query(...), candidates: List[str] = Query(...), top_k: int = Query(...), character_type: int = Query(...)):
243
+ return StreamingResponse(gen_stream(QA, candidates, top_k, character_type), media_type="text/event-stream")
244
 
245
 
246
+ async def gen_stream(QA, candidates, top_k, character_type):
247
+ for value in make_gen(QA, candidates, top_k, character_type):
248
  yield value
249
 
250