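# NOTE: the page this file was captured from showed the status
# "Spaces: Runtime error"; that text is capture residue, not program code.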
import platform | |
import os | |
import time | |
from threading import Thread | |
from rich.text import Text | |
from rich.live import Live | |
from model.infer import ChatBot | |
from config import InferConfig | |
# Module-level singletons: the inference configuration and the ChatBot built
# from it. Both are constructed once, when this module is imported.
infer_config = InferConfig()
chat_bot = ChatBot(infer_config=infer_config)

# Shell command that clears the terminal on the current operating system.
clear_cmd = 'cls' if platform.system().lower() == 'windows' else 'clear'
# Banner shown at start-up and re-printed at the top of every redraw.
welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
print(welcome_txt)


def build_prompt(history: list[list[str]]) -> str:
    """Render the banner plus the whole conversation as one string.

    Args:
        history: Conversation turns in order; each item is a
            ``[query, response]`` pair.

    Returns:
        ``welcome_txt`` followed, for every turn, by the user's query and
        the bot's response, each prefixed with its ANSI-coloured label.
    """
    # Collect the pieces and join once instead of growing a string with
    # ``+=`` on every turn.
    parts = [welcome_txt]
    for query, response in history:
        parts.append('\n\033[0;33;40m用户:\033[0m{}'.format(query))
        parts.append('\n\033[0;32;40mChatBot:\033[0m\n{}\n'.format(response))
    return ''.join(parts)
# Stop flag for the spinner thread; chat() sets it to True once the
# non-stream answer is available.
STOP_CIRCLE: bool = False


def circle_print(total_time: int = 60) -> None:
    """Print a busy spinner while a non-stream chat call is running.

    Every 0.25 s one of four spinner frames is redrawn at the start of the
    current line. The loop ends when the module-level ``STOP_CIRCLE`` flag
    is true or after ``total_time * 4`` frames, whichever comes first; the
    cursor is then returned to the start of the line.

    Args:
        total_time: Upper bound on the spinner's running time, in seconds.
    """
    frames = ["\\", "|", "/", "—"]
    for tick in range(total_time * 4):
        time.sleep(0.25)
        if STOP_CIRCLE:
            # The answer arrived during the sleep: do not draw another frame.
            break
        print("\r{}".format(frames[tick % 4]), end="", flush=True)
    print("\r", end='', flush=True)
def chat(stream: bool = True) -> None:
    """Run the interactive console chat loop until the user exits.

    Each round reads one line from stdin. ``exit`` ends the loop, ``cls``
    clears the screen and the conversation history, and an empty line
    re-prompts. Any other input is sent to the module-level ``chat_bot``.

    Args:
        stream: When true, the answer is rendered incrementally with
            ``rich.live.Live`` as ``chat_bot.stream_chat`` yields pieces,
            then the screen is cleared and the full conversation is
            redrawn. When false, a spinner thread runs while
            ``chat_bot.chat`` computes the whole answer, which is then
            printed once; this mode does not record history.
    """
    global STOP_CIRCLE
    history = []

    while True:
        print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
        try:
            input_txt = input()
        except EOFError:
            # stdin was closed (Ctrl-D or a non-interactive run): leave the
            # loop instead of crashing with a traceback.
            break

        if not input_txt:
            print('请输入问题')
            continue

        command = input_txt.lower()

        # Exit.
        if command == 'exit':
            break

        # Clear the screen and forget the conversation so far.
        if command == 'cls':
            history = []
            os.system(clear_cmd)
            print(welcome_txt)
            continue

        if not stream:
            STOP_CIRCLE = False
            thread = Thread(target=circle_print)
            thread.start()
            try:
                outs = chat_bot.chat(input_txt)
            finally:
                # Always stop and join the spinner, even when inference
                # raises; otherwise the non-daemon thread keeps the process
                # alive until its own timeout expires.
                STOP_CIRCLE = True
                thread.join()
            print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')
            continue

        streamer = chat_bot.stream_chat(input_txt)
        rich_text = Text()
        stream_txt = []

        print("\r\033[0;32;40mChatBot:\033[0m\n", end='')
        # Live re-renders rich_text as pieces are appended to it.
        with Live(rich_text, refresh_per_second=15):
            for word in streamer:
                rich_text.append(word)
                stream_txt.append(word)

        # Fall back to a canned reply when the model produced no text.
        answer = ''.join(stream_txt) or "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
        history.append([input_txt, answer])

        # Redraw the whole conversation from a clean screen.
        os.system(clear_cmd)
        print(build_prompt(history), flush=True)
# Script entry point: start the console chat in streaming mode.
if __name__ == '__main__':
    chat(stream=True)