|
import re |
|
from typing import List |
|
from loguru import logger |
|
|
|
import config |
|
from llm.call_llm import get_completion, get_completion_from_messages |
|
from words_db import words_db |
|
from create_db import get_similar_k_words |
|
from prompts import trans_prompt, query_prompt, learn_prompt |
|
from prompts import system_message_mapper |
|
|
|
|
|
def format_common_prompt(raw_prompt, variable):
    """Fill a prompt template's positional placeholder(s) with ``variable``.

    The template is formatted as ``raw_prompt.format(variable)``, so every
    ``{}`` / ``{0}`` slot receives ``variable`` rendered through ``format()``
    (a list, for example, appears as its ``repr``-style text).
    """
    filled_prompt = raw_prompt.format(variable)
    return filled_prompt
|
|
|
def format_chat_prompt(message, chat_history) -> str:
    """Render a conversation as a plain-text User/Assistant transcript.

    Each past ``(user, assistant)`` pair in ``chat_history`` becomes a
    newline-prefixed "User: ..." line followed by an "Assistant: ..." line.
    The new ``message`` is appended last with an empty "Assistant:" cue for
    the model to complete.
    """
    pieces = [
        f"\nUser: {user_turn}\nAssistant: {bot_turn}"
        for user_turn, bot_turn in chat_history
    ]
    pieces.append(f"\nUser: {message}\nAssistant:")
    return "".join(pieces)
|
|
|
|
|
def respond(message, chat_history,
            llm="gpt-3.5-turbo", history_len=3, temperature=0.1, max_tokens=2048):
    """Handle one chat turn and return the updated UI state.

    Resolution order:
      1. explicit ``:command`` syntax, via ``command_parser``;
      2. natural language mapped to a command, via ``command_mapper``
         (this makes an LLM call);
      3. otherwise an ordinary LLM chat completion.

    Args:
        message: the user's input. ``None`` or ``""`` is ignored.
        chat_history: list of ``(user, assistant)`` pairs. The new turn is
            appended to this list in place, and the full list is returned.
        llm: model name forwarded to ``get_completion``.
        history_len: how many of the most recent turns are sent to the LLM as
            context; ``<= 0`` sends none. This limits only the prompt, not
            the returned history.
        temperature: forwarded to ``get_completion``.
        max_tokens: forwarded to ``get_completion``.

    Returns:
        ``(text, chat_history)``. ``text`` is ``""`` on success, or the
        error message if any step raised.
    """
    # Check first so the command handlers never receive None and empty input
    # does not trigger an LLM call in command_mapper.
    if message is None or len(message) < 1:
        return "", chat_history

    try:
        # 1) Explicit ":command" syntax; 2) otherwise let the LLM map the
        #    natural-language request onto a command.
        command_reply = command_parser(message)
        if not command_reply:
            command_reply = command_mapper(message)
        if command_reply:
            chat_history.append((message, command_reply))
            return "", chat_history

        # 3) Plain chat. Only the most recent turns are used as LLM context;
        #    the caller's full history is kept so older turns stay visible.
        context = chat_history[-history_len:] if history_len > 0 else []
        formatted_prompt = format_chat_prompt(message, context)
        bot_message = get_completion(
            formatted_prompt,
            llm,
            api_key=config.api_key,
            temperature=temperature, max_tokens=max_tokens)
        # NOTE(review): r"\\n" matches a literal backslash followed by "n",
        # not a real newline character. Confirm this is the intended input.
        bot_message = re.sub(r"\\n", '<br/>', bot_message)
        chat_history.append((message, bot_message))
        return "", chat_history

    except Exception as e:
        logger.exception(f"respond failed for message: {message}")
        return str(e), chat_history
|
|
|
def command_parser(input: str) -> str:
    """Parse a ``:command`` string and execute it.

    Supported commands:
        :add <word1> <word2> ...   add words to the vocabulary database
        :remove <word1> ...        remove words from the database
        :learn <query_word>        learning material based on the database
        :query <word>              look up a single word
        :show                      list every word in the database
        :help                      list the supported commands

    The first whitespace-separated token must equal the command exactly
    (so ``:addx`` is not treated as ``:add``).

    Returns:
        The reply to show the user, or ``""`` when ``input`` is empty or is
        not a recognised command.
    """
    # Split on any whitespace so repeated spaces or a trailing newline
    # (e.g. in an LLM reply routed here by command_mapper) do not produce
    # empty or newline-polluted arguments.
    tokens = input.split() if input else []
    if not tokens:
        return ""
    command, args = tokens[0], tokens[1:]

    if command == ":add":
        return add_words(args)

    if command == ":remove":
        return remove_words(args)

    if command == ":learn":
        if len(args) != 1:
            return "学习模式将基于词库进行,请指定一个query单词"
        query = args[0]
        info = learn_words(query)
        return f"Based on your query word: {query} and dictionary, learning sentence is:\n{info}"

    if command == ":query":
        # Require exactly one word; a bare ":query" must not index past args.
        if len(args) != 1:
            return "查询模式仅支持单个单词,请使用:query <word>进行查询"
        word = args[0]
        info = query_word(word)
        return f"{word}\n{info}"

    if command == ":show":
        return show_all_words()

    if command == ":help":
        return ("目前支持的指令有:\n:add <word1> <word2> ...\n:remove <word1> <word2> ...\n"
                ":learn <query_word>\n:query <word>\n:show")

    return ""
|
|
|
|
|
def show_all_words() -> str:
    """Return a user-facing listing of everything ``words_db.query_word()`` yields.

    Any exception raised while querying or rendering is logged, and the
    short failure message "查询失败" is returned instead of raising.
    """
    try:
        listing = f"目前词库中的所有单词:\n{words_db.query_word()}"
    except Exception as err:
        logger.error(str(err))
        return "查询失败"
    return listing
|
|
|
def add_words(input: List[str]) -> str:
    """Translate each word with the LLM and store it in ``words_db``.

    Args:
        input: words to add. Each one is translated by formatting
            ``trans_prompt`` with the word and calling ``get_completion``.

    Returns:
        A user-facing status message. If a translation or a DB write raises,
        the error is logged and a failure message is returned instead of
        propagating the exception. An empty ``input`` returns a usage hint.
    """
    if not input:
        return "请指定要添加的单词,例如 :add <word1> <word2> ..."

    try:
        # Translate everything first, so an LLM failure leaves the DB untouched.
        word_tuple_list = [
            (word, get_completion(
                prompt=format_common_prompt(trans_prompt, word),
                api_key=config.api_key))
            for word in input
        ]
        for word, definition in word_tuple_list:
            words_db.add_word(word, definition)
            logger.info(f"已经添加单词: {word} 和其释义: {definition}")
    except Exception as e:
        logger.error(str(e))
        return f"添加单词失败: {input}"

    return f"已添加单词: {input}"
|
|
|
def remove_words(input: List[str]) -> str:
    """Delete each word (and its stored definition) from ``words_db``.

    Args:
        input: words to delete, processed in order.

    Returns:
        A user-facing status message. If a deletion raises, the error is
        logged and a failure message is returned; words processed before the
        failure have already been deleted. An empty ``input`` returns a usage
        hint instead of reporting a no-op as success.
    """
    if not input:
        return "请指定要删除的单词,例如 :remove <word1> <word2> ..."

    try:
        for word in input:
            words_db.delete_word(word)
            logger.info(f"已经删除单词: {word} 和其释义")
    except Exception as e:
        logger.error(str(e))
        return f"删除单词失败: {input}"

    return f"已删除单词: {input}"
|
|
|
def learn_words(query_word) -> str:
    """Generate learning material from words similar to ``query_word``.

    Fetches the similar words with ``get_similar_k_words`` and fills
    ``learn_prompt`` with them. Returns the LLM completion for that prompt.
    The selected words are logged after the completion returns.
    """
    similar_words = get_similar_k_words(query_word)
    learn_request = format_common_prompt(learn_prompt, similar_words)
    lesson = get_completion(prompt=learn_request, api_key=config.api_key)
    logger.info(f"进入学习模式,学习下列单词: {similar_words}")
    return lesson
|
|
|
def query_word(input: str) -> str:
    """Ask the LLM about a single word.

    Fills ``query_prompt`` with ``input`` and returns the completion text.
    The query is logged after the completion returns.
    """
    lookup_request = format_common_prompt(query_prompt, input)
    explanation = get_completion(prompt=lookup_request, api_key=config.api_key)
    logger.info(f"查询单词: {input}")
    return explanation
|
|
|
def command_mapper(input: str) -> str:
    """Map a natural-language message to a ``:command`` via the LLM and run it.

    Sends ``system_message_mapper`` (system role) plus the user's text to
    ``get_completion_from_messages``. The reply is then executed with
    ``command_parser``; presumably the system prompt asks the LLM to answer
    with a ``:command`` string (verify in ``prompts``).

    Returns:
        The result of ``command_parser`` on the LLM reply. This is ``""`` when
        the reply is empty or is not a recognised command.
    """
    user_message = input
    messages = [
        {'role': 'system', 'content': system_message_mapper},
        {'role': 'user', 'content': f"{user_message}"},
    ]

    # Must be the API key string, matching every other LLM call in this module,
    # not the config module itself.
    mapped_reply = get_completion_from_messages(messages, api_key=config.api_key)

    # Guard against an empty or None reply before parsing it as a command.
    mapped_command = command_parser(mapped_reply) if mapped_reply else ""
    logger.info(f"用户输入: {user_message}\n指令解析器输出: {mapped_command}")

    return mapped_command
|
|
|
|