|
import os, base64 |
|
import requests, json |
|
import gradio as gr |
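
# Gradio front-end for a streaming RAG query service: it POSTs the user's
# question to the service's /query endpoint, streams partial replies into the
# chat window, then appends deduplicated source passages and any referenced images.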
|
|
|
|
|
GREEN = '\033[1;32m' |
|
BLUE = '\033[1;34m' |
|
RESET = '\033[0m' |
|
URL = "https://ai1071.4dstaging.com/v1/" |
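# Endpoints used below (inferred from the calls in this file): POST {URL}query
# for answers and GET {URL}images?image_id=... for image files.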
|
|
|
# Response codes returned in the service's 'code' field.
VALID_ANSWER, QUERY_FAIL, INVALID_ANSWER = 0, 1, 2
|
|
|
VICTORIA_HARBOUR, MIC = 0, 1
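# CUSTOMER selects which deployment (knowledge base, query mode, sample questions) to serve.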
|
CUSTOMER = MIC |
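
# Each entry in MODES bundles the display name, the server-side knowledge-base
# path, the query mode index, and the retrieval temperature for one deployment.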
|
|
|
MODES = [ |
|
{ |
|
"name": " ", |
|
"query_mode_indx": 5, |
|
"retrieval_temperature": 0.2, |
|
"path": r"E:\workspace\RAG_data\20240412_superQuery\db\EC_test_all\20240603_haigang_qa", |
|
"sample_questions": [ |
|
"這裡可以book位嗎?", "可以book位嗎?", "Hi", "蟹", "魚", "會員", "訂枱" |
|
"锡我?", "可唔可以幫我寫一張菜單?", |
|
"可以加大長腳蟹嗎?","想查詢最新堂食優惠", |
|
"有什麼優惠", "宴會菜單", "有長腳蟹?", "積分如何運作?", "點加入會員?", |
|
"套餐可轉其他菜式嗎?", "網購限定優惠可以堂食嗎?", "當日海鮮供應情況?" |
|
], |
|
|
|
},{ |
|
"name": "MiC Modular Integrated Construction - HK (Beta)", |
|
"query_mode_indx": 4, |
|
"retrieval_temperature": 0.2, |
|
|
|
"path": r"E:\workspace\RAG_data\20240412_superQuery\db\EC_test_all\20240619_mic_demo", |
|
"sample_questions": [ |
|
"What is MIC?", "優惠措施", "Please introduce CIC", "Key Technologies of MIC", |
|
"組裝合成建築法", "物料或產品規格", "MIC safety." |
|
], |
|
} |
|
] |
|
|
|
|
|
# Sample questions surfaced as clickable examples in the chat UI.
questions = MODES[CUSTOMER]['sample_questions']
|
|
|
def the_answer(response: dict):
    """Extract the answer text between the 'Answer(GPT4):' and 'References:' markers."""
    a = response['msg'].split('Answer(GPT4):')[1].split('References:')[0]
    # str.strip() returns a new string, so the result must be returned, not discarded.
    return a.strip()
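
# the_answer assumes the final message embeds the answer between two markers,
# e.g. (format inferred from the split calls above, not from server docs):
#   "...Answer(GPT4): <answer text> References: ..."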
|
|
|
def the_references(response: dict, user_query: str):
    """Collect the page contents of the answer's source documents, dropping entries that merely repeat the user's query."""
    ref_contents = []
    if response["code"] == VALID_ANSWER:
        for ref in response["data"]["source_docs"]:
            ref_contents.append(ref["page_content"])
    # An empty list passes through the filter unchanged, so failed queries return [].
    return filter_repeated(user_query, ref_contents)
|
|
|
def filter_repeated(user_query, ref_contents: list):
    """Drop references whose stored question exactly duplicates the user's query."""
    ref_contents_filtered = []
    for ref in ref_contents:
        try:
            question = next(iter(ref.get('問題').values()))
        except StopIteration as e:
            # '問題' maps to an empty dict; keep the reference, since there is nothing to compare.
            print(e)
            ref_contents_filtered.append(ref)
            continue
        except Exception as e:
            # '問題' is missing or the reference is not a mapping; keep it as-is.
            print(e)
            ref_contents_filtered.append(ref)
            continue
        print(question)
        print("question == user_query: " + str(question == user_query))
        if question != user_query:
            ref_contents_filtered.append(ref)
    return ref_contents_filtered
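
# Illustration (hypothetical data): assuming each page_content dict stores the
# original question under a '問題' mapping, a reference that merely repeats the
# user's query is dropped:
#   filter_repeated("有什麼優惠", [{'問題': {'0': '有什麼優惠'}}])  ->  []
#   filter_repeated("有什麼優惠", [{'問題': {'0': '宴會菜單'}}])    ->  list kept unchanged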
|
|
|
def get_images_from_source(source_docs):
    """Fetch any image sources referenced by the documents and return them as inline <img> tags."""
    image_exts = [".jpg", ".jpeg", ".png"]
    source_list = [doc['source'] for doc in source_docs]
    source_img_list = [source for source in source_list if os.path.splitext(source)[1] in image_exts]

    buffer_img_str = ""
    for source in source_img_list:
        response = requests.get(URL + f"images?image_id={source}")
        if response.status_code == 200:
            base64_image = base64.b64encode(response.content).decode("utf-8")
            img_html = f'<img src="data:image/png;base64,{base64_image}" alt="{os.path.basename(source)}">'
            buffer_img_str += "\n" + img_html
        else:
            print(f"Error fetching image {source}: HTTP {response.status_code}")
    return buffer_img_str
|
|
|
def all_info(response): |
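    """Format a response dict as one coloured key/value pair per line for terminal debugging."""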
|
info="\n".join([f"{GREEN}{key}{RESET}: {value}" for key, value in response.items()]) |
|
return info |
|
|
|
def request_stream_chat(question: str, history):
    """POST the question to the /query endpoint and stream the reply into the chat."""
    if not question:
        yield "Hello! What would you like to know?"
        return
|
|
|
    payload = {
        "prompt": question,
        "retrieval_temperature": MODES[CUSTOMER]['retrieval_temperature'],
        "query_mode_indx": MODES[CUSTOMER]['query_mode_indx'],
        "path": MODES[CUSTOMER]['path'],
        "stream": True,
        "LLM_type": "gpt"
    }
|
reply_buffer = "" |
|
with requests.post(url=URL+"query", json=payload, stream=True) as r_stream: |
|
for line in r_stream.iter_lines(): |
|
if line: |
|
line = json.loads(line) |
|
                if line['finished']:
                    response = line
                    msg = response['msg']

                    if payload['query_mode_indx'] == 5:
                        # Q&A mode: list the deduplicated source passages under the answer.
                        source_docs_content = the_references(response, question)
                        source_docs_content_str = "\n".join([str(content) for content in source_docs_content])
                        response_str = msg + "\n\nSource documents:\n" + source_docs_content_str
                    else:
                        # 'reference' may be absent from the response; fall back to an empty string.
                        response_str = msg + "\n\n" + (response.get('reference') or "")

                    source_docs = response['data']['source_docs']
                    image_str = get_images_from_source(source_docs)
                    response_str += "\n" + image_str
                    yield response_str
                    break
|
                else:
                    # Still streaming: accumulate partial tokens and yield the running reply.
                    reply_buffer += line['reply']
                    yield reply_buffer
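
# Each streamed line from the service is assumed (inferred from the parsing
# above, not from server docs) to be a JSON object shaped roughly like:
#   {"finished": false, "reply": "<partial text>"}
#   {"finished": true, "code": 0, "msg": "...", "reference": "...",
#    "data": {"source_docs": [{"page_content": ..., "source": "<file>"}]}}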
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
gr.ChatInterface( |
|
request_stream_chat, |
|
examples=questions, |
|
|
|
chatbot=gr.Chatbot(height=450), |
|
textbox=gr.Textbox(placeholder="喺呢度問我問題.", container=False, scale=7), |
|
title=MODES[CUSTOMER]['name'], |
|
        description="智能查詢",
|
theme="soft", |
|
cache_examples=False, |
|
retry_btn=None, |
|
undo_btn="Delete Previous", |
|
clear_btn="Clear", |
|
fill_height=True, |
|
).launch(share=True) |
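
# launch(share=True) also publishes the app on a temporary public *.gradio.live
# URL in addition to the local server.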