|
import re |
|
import os |
|
import unicodedata |
|
from typing import List |
|
import uuid |
|
import hashlib |
|
|
|
import pandas as pd |
|
|
|
from common.call_llm import chat_stream_generator |
|
|
|
# Endpoint of the doc-QA LLM service; None when the env var is unset.
# NOTE(review): chat_stream_generator is called with endpoint=None in that
# case — confirm the caller/service tolerates it.
DOC_QA_ENDPOINT = os.environ.get("DOC_QA_ENDPOINT")

# Overall prompt skeleton filled in by get_prompt(): reference documents,
# serialized chat history, the user's question and the numbered answer rules.
# The template text is user-facing model input — do not alter it.
prompt_template = """你是由猎户星空开发的AI助手,你的名字叫聚言。你可以根据下面给出的参考资料和聊天历史来回答用户问题。

### 参考资料 ###
{context}

### 聊天历史 ###
{chat_history}

### 用户问题 ###
{question}

### 回答要求 ###
{requirement}
"""
|
|
|
|
|
def document_prompt_template():
    """Return the per-fragment context entry template.

    The returned string is later filled via ``str.format`` with ``doc_id``
    (an opaque citation id) and ``page_content`` (the fragment text) when
    the reference-material section of the prompt is assembled.
    """
    entry_template = """["Source_id": {doc_id},"Content": "{page_content}"]"""
    return entry_template
|
|
|
|
|
|
|
def language_detect(text: str) -> str:
    """Best-effort language detection based on Unicode character names.

    Counts characters whose Unicode names indicate CJK ideographs ("zh"),
    kana ("ja"), Thai ("th"), Hangul ("ko") or Latin ("en") script, then
    returns the language with the highest share of named characters.
    Latin counts are down-weighted by 4 so that stray Latin letters in
    mixed text do not dominate.

    Args:
        text: Raw user input; bullet/punctuation noise is stripped first.

    Returns:
        One of "zh", "ja", "ko", "en", "th", or "" when no named
        characters remain (empty or symbol-only input).
    """
    text = re.sub(r"([ ■◼•*…— �●⚫]+|[·\.~•、—'}\n\t]{1,})", '', text.strip())
    stats = {
        "zh": 0,
        "ja": 0,
        "ko": 0,
        "en": 0,
        "th": 0,
        "other": 0
    }
    char_count = 0
    for char in text:
        try:
            code_name = unicodedata.name(char)
        except ValueError:
            # Unassigned/control characters have no Unicode name; skip them.
            continue
        char_count += 1

        if 'CJK' in code_name:
            stats["zh"] += 1
        elif 'HIRAGANA' in code_name or 'KATAKANA' in code_name:
            stats["ja"] += 1
        elif "THAI" in code_name:
            stats["th"] += 1
        elif 'HANGUL' in code_name:
            stats["ko"] += 1
        elif 'LA' in code_name:
            # Matches e.g. "LATIN SMALL LETTER A".  NOTE(review): 'LA' is a
            # loose substring and can also match unrelated names (e.g.
            # "BLACK STAR") — confirm this shorthand is intended.
            stats["en"] += 1
        else:
            stats["other"] += 1

    # Bug fix: previously divided by char_count unconditionally, raising
    # ZeroDivisionError for empty / symbol-only input.
    if char_count == 0:
        return ""

    lang = ""
    ratio = 0.0
    for lan in stats:
        if lan == "other":
            continue

        if lan == "en":
            # Down-weight Latin characters (they appear in most mixed text).
            stats[lan] /= 4.0
        lan_r = float(stats[lan]) / char_count
        if ratio < lan_r:
            lang = lan
            ratio = lan_r

    return lang
|
|
|
|
|
def language_prompt(lan: str) -> str:
    """Map a detected language code to the Chinese label used in the prompt.

    Japanese, "other" and any unrecognized code deliberately fall back to
    "中文" (Chinese), matching the assistant's default answer language.
    """
    code = lan.lower()
    if code == "en":
        return "英文"
    if code == "ko":
        return "韩文"
    if code == "th":
        return "泰文"
    # "zh", "zh_gd", "ja", "other" and unknown codes all answer in Chinese.
    return "中文"
|
|
|
|
|
def _get_chat_history(chat_history: List[List]) -> str: |
|
if not chat_history: |
|
return "" |
|
chat_history_text = "" |
|
for human_msg, ai_msg in chat_history: |
|
human = "{'Human': '" + human_msg + "'}" |
|
ai = "{'AI': '" + ai_msg + "'}" |
|
chat_history_text += "[" + ", ".join([human, ai]) + "]\n" |
|
return chat_history_text |
|
|
|
|
|
def get_prompt(context: str, chat_history: str, question: str, trapped_switch: int, fallback: str,
               citations_switch: int) -> str:
    """Assemble the final doc-QA prompt from context, history and switches.

    Args:
        context: Serialized reference-document entries.
        chat_history: Serialized prior turns (see ``_get_chat_history``).
        question: The current user question.
        trapped_switch: 1 to instruct the model to answer with ``fallback``
            when the reference material is insufficient.
        fallback: Canned reply used when ``trapped_switch`` is enabled.
        citations_switch: Truthy to request ``[[Source_id]]`` citations.

    Returns:
        The fully formatted prompt string built from ``prompt_template``.
    """
    requirements = ["1. 你只能根据上面参考资料中给出的事实信息来回答用户问题,不要胡编乱造。",
                    "2. 如果向用户提出澄清问题有助于回答问题,可以尝试提问。"]
    next_no = 3

    # Optional rule: canned fallback answer when the context is insufficient.
    if trapped_switch == 1 and len(fallback) > 0:
        requirements.append(
            f"{next_no}. 如果参考资料中的信息不足以回答用户问题,请直接回答下面三个双引号中的内容:\"\"\"{fallback}\"\"\"。")
        next_no += 1

    # Optional rule: ask the model to cite the Source_ids it used.
    if citations_switch:
        requirements.append(
            f"{next_no}. 如果你给出的答案里引用了参考资料中的内容,请在答案的结尾处添加你引用的Source_id,引用的Source_id值来自于参考资料中,并用两个方括号括起来。示例:[[d97b811489b73f46c8d2cb1bc888dbbe]]、[[b6be48868de736b90363d001c092c019]]")
        next_no += 1

    # Always-on rule: answer in the language detected from the question.
    detected = language_detect(question)
    requirements.append(
        f"{next_no}. 请你以第一人称并且用严谨的风格来回答问题,一定要用{language_prompt(detected)}来回答,并且基于事实详细阐述。")

    return prompt_template.format(context=context, chat_history=chat_history,
                                  question=question, requirement="\n".join(requirements))
|
|
|
|
|
def generate_doc_qa(input_text: str, history: List[List[str]], doc_df: "pd.DataFrame", trapped_switch: str, fallback: str,
                    citations_switch: str):
    """Stream doc-QA chat responses for the given question and documents.

    Builds a prompt from the document fragments in ``doc_df`` plus recent
    chat history, streams the LLM answer into the last history turn, and
    rewrites cited source ids back to their fragment names.

    Args:
        input_text: Current user question; defaults to "你好" when falsy.
        history: Chat turns as [human, ai] pairs; only the last 5 are used.
            NOTE(review): the streaming loop appends to ``history[-1][1]``,
            so the caller must pass a non-empty history whose last turn
            holds the in-progress answer — an empty history would raise
            IndexError.  Confirm with the caller (gradio-style UI presumed).
        doc_df: DataFrame with "文档片段名称" / "文档片段内容" columns.
        trapped_switch: "自定义话术" enables the fallback answer rule.
        fallback: Canned reply used when the context is insufficient.
        citations_switch: "开启引用" requests [[Source_id]] citations.

    Yields:
        ``(None, history)`` after each streamed chunk, with the partial
        answer accumulated in ``history[-1][1]``.
    """

    print(f"input_text: {input_text}, history: {history}, page_content: {doc_df}, trapped_switch: {trapped_switch}, fallback: {fallback}, citations_switch: {citations_switch}")

    # Normalize UI string switches to the 0/1 ints expected by get_prompt().
    citations_switch = 1 if citations_switch == "开启引用" else 0
    trapped_switch = 1 if trapped_switch == "自定义话术" else 0
    fallback = fallback or ""

    input_text = input_text or "你好"
    # Keep only the 5 most recent turns to bound the prompt size.
    history = (history or [])[-5:]

    # Drop rows with no fragment content.
    doc_df = doc_df[doc_df["文档片段内容"].notna()]

    context = ""
    source_id_map = dict()  # opaque source id -> human-readable fragment name
    for _, row in doc_df.iterrows():
        if not row["文档片段内容"] or not row["文档片段名称"]:
            continue
        # Random opaque id per fragment, so the model cites ids rather than
        # fragment names; mapped back to names after streaming completes.
        source_id = hashlib.md5(str(uuid.uuid4()).encode("utf-8")).hexdigest()
        source_id_map[source_id] = row["文档片段名称"]
        context += document_prompt_template().format(doc_id=source_id, page_content=row["文档片段内容"]) + "\n\n"

    prompt = get_prompt(context.strip(), _get_chat_history(history), input_text, trapped_switch, fallback,
                        citations_switch)
    print(f"docQA prompt: {prompt}")
    messages = [{"role": "user", "content": prompt}]

    stream_response = chat_stream_generator(messages=messages, endpoint=DOC_QA_ENDPOINT)

    cache = ""

    for character in stream_response:
        # Once a "[" appears, buffer everything that follows so citation
        # markers can be rewritten before display.
        # NOTE(review): this stops live streaming for the remainder of the
        # answer after the FIRST "[", even mid-sentence — confirm intended.
        if "[" in character or cache:
            cache += character
            continue
        history[-1][1] += character
        yield None, history

    if cache:
        # Extract ids cited as [[...]] and swap them for fragment names.
        source_ids = re.findall(r"\[\[(.*?)\]\]", cache)
        print(f"Matched source ids {source_ids}")
        for source_id in source_ids:
            # Unknown ids are left as-is (mapped to themselves).
            origin_source_id = source_id_map.get(source_id, source_id)
            cache = cache.replace(source_id, origin_source_id)

        # Flush the buffered tail (with rewritten citations) in one chunk.
        history[-1][1] += cache
        yield None, history