import platform
import os
import time
from threading import Thread

from rich.text import Text
from rich.live import Live

from model.infer import ChatBot
from config import InferConfig

infer_config = InferConfig()
chat_bot = ChatBot(infer_config=infer_config)

# Shell command used to clear the terminal on the current platform.
clear_cmd = 'cls' if platform.system().lower() == 'windows' else 'clear'

welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
print(welcome_txt)


def build_prompt(history: list[list[str]]) -> str:
    """Render the welcome banner followed by every conversation turn.

    Args:
        history: list of ``[query, response]`` pairs, oldest first.

    Returns:
        The text to print after clearing the screen: the welcome banner, then
        each user query (yellow label) and ChatBot response (green label),
        coloured with ANSI escape codes.
    """
    parts = [welcome_txt]
    for query, response in history:
        parts.append('\n\033[0;33;40m用户:\033[0m{}'.format(query))
        parts.append('\n\033[0;32;40mChatBot:\033[0m\n{}\n'.format(response))
    return ''.join(parts)


# Set to True by chat() to tell the circle_print spinner thread to stop.
STOP_CIRCLE: bool = False


def circle_print(total_time: int = 60) -> None:
    """Draw a busy spinner while a non-streaming chat request is running.

    The spinner glyph is redrawn every 0.25 s for at most ``total_time``
    seconds. It stops early once the module-level ``STOP_CIRCLE`` flag is
    True, and finally returns the cursor to the start of the line.

    Args:
        total_time: upper bound, in seconds, on how long the spinner runs.
    """
    list_circle = ["\\", "|", "/", "—"]
    for i in range(total_time * 4):
        time.sleep(0.25)
        print("\r{}".format(list_circle[i % 4]), end="", flush=True)
        # Only reads the flag, so no `global` declaration is needed here.
        if STOP_CIRCLE:
            break
    print("\r", end='', flush=True)


def chat(stream: bool = True) -> None:
    """Run the interactive read/reply loop against the module-level chat_bot.

    Typing ``exit`` (or sending EOF / Ctrl-C at the prompt) leaves the loop.
    Typing ``cls`` clears the conversation history and the screen.

    Args:
        stream: if True, tokens from ``chat_bot.stream_chat`` are rendered
            live with rich. Each finished turn is recorded in the history,
            and the whole conversation is redrawn after every reply. If
            False, ``chat_bot.chat`` is called once while a spinner thread
            runs; the reply is printed but not added to the history.
    """
    global STOP_CIRCLE
    history: list[list[str]] = []

    while True:
        print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
        try:
            input_txt = input()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: quit cleanly instead of
            # crashing with a traceback.
            print()
            break

        # Reject empty and whitespace-only input instead of sending it on.
        if not input_txt.strip():
            print('请输入问题')
            continue

        command = input_txt.strip().lower()

        # Quit.
        if command == 'exit':
            break

        # Clear the screen and forget the conversation.
        if command == 'cls':
            history = []
            os.system(clear_cmd)
            print(welcome_txt)
            continue

        if not stream:
            STOP_CIRCLE = False
            # Daemon thread, so a stray spinner can never keep the process
            # alive on its own.
            spinner = Thread(target=circle_print, daemon=True)
            spinner.start()
            try:
                outs = chat_bot.chat(input_txt)
            finally:
                # Always stop and reap the spinner, even if the model call
                # raises.
                STOP_CIRCLE = True
                spinner.join()
            print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')
            continue

        streamer = chat_bot.stream_chat(input_txt)
        pieces: list[str] = []
        rich_text = Text()
        print("\r\033[0;32;40mChatBot:\033[0m\n", end='')
        # Live re-renders rich_text as pieces are appended to it.
        with Live(rich_text, refresh_per_second=15):
            for word in streamer:
                rich_text.append(word)
                pieces.append(word)

        reply = ''.join(pieces)
        if not reply:
            reply = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"

        history.append([input_txt, reply])

        # Redraw the full conversation so the live-rendered area is replaced
        # by the plain transcript.
        os.system(clear_cmd)
        print(build_prompt(history), flush=True)


if __name__ == '__main__':
    chat(stream=True)